Unverified commit 889efafd authored by Muyang Li, committed by GitHub

[major] Adding some test workflows; support multiple batches

parents 68dafdfa 16979769
<!-- Thank you for your contribution! We appreciate it. The following guidelines will help improve your pull request and facilitate feedback. If anything is unclear, don't hesitate to submit your pull request and ask the maintainers for assistance. -->
## Motivation
<!-- Explain the purpose of this PR and the goals it aims to achieve. -->
## Modifications
<!-- Describe the changes made in this PR. -->
## Checklist
- [ ] Code is formatted using Pre-Commit hooks.
- [ ] Relevant unit tests are added in the [`tests`](../tests) directory.
- [ ] [README](../README.md) and example scripts in [`examples`](../examples) are updated if necessary.
- [ ] Throughput/latency benchmarks and quality evaluations are included where applicable.
- [ ] **For reviewers:** If you're only helping merge the main branch and haven't contributed code to this PR, please remove yourself as a co-author when merging.
- [ ] Please feel free to join our [Slack](https://join.slack.com/t/nunchaku/shared_invite/zt-3170agzoz-NgZzWaTrEj~n2KEV3Hpl5Q), [Discord](https://discord.gg/Wk6PnwX9Sm) or [WeChat](https://github.com/mit-han-lab/nunchaku/blob/main/assets/wechat.jpg) to discuss your PR.
\ No newline at end of file
name: pr_test_ampere
on:
workflow_dispatch:
inputs:
test_target:
description: 'What to test: "pr" or "branch"'
required: true
type: choice
options:
- pr
- branch
pr_number:
description: 'Pull Request Number (only if test_target == "pr")'
required: false
branch_name:
description: 'Branch name (only if test_target == "branch")'
default: 'main'
required: false
# push:
# branches: [ main ]
# paths:
# - "nunchaku/**"
# - "src/**"
# - "tests/**"
# - "examples/**"
# pull_request:
# types: [ opened, synchronize, reopened, edited ]
# paths:
# - "nunchaku/**"
# - "src/**"
# - "tests/**"
# - "examples/**"
# issue_comment:
# types: [ created ]
concurrency:
group: ${{ github.repository }}-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
check-comment:
if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'issue_comment' && github.event.issue.pull_request && !github.event.pull_request.draft) }}
runs-on: [ self-hosted, ampere ]
outputs:
should_run: ${{ steps.check.outputs.should_run }}
steps:
- id: check
run: |
body="${{ github.event.comment.body }}"
body_lower=$(echo "$body" | tr '[:upper:]' '[:lower:]')
if [[ "$body_lower" == "run tests" || "$body_lower" == "run test" ]]; then
echo "should_run=true" >> $GITHUB_OUTPUT
else
echo "should_run=false" >> $GITHUB_OUTPUT
fi
set-up-build-env:
runs-on: [ self-hosted, ampere ]
needs: [ check-comment ]
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Determine ref
id: set-ref
run: |
if [[ "${{ github.event.inputs.test_target }}" == "pr" ]]; then
echo "ref=refs/pull/${{ github.event.inputs.pr_number }}/merge" >> $GITHUB_OUTPUT
else
echo "ref=refs/heads/${{ github.event.inputs.branch_name }}" >> $GITHUB_OUTPUT
fi
- name: Checkout
uses: actions/checkout@v4
with:
# ref: ${{ github.event.pull_request.head.sha || github.sha }}
ref: ${{ steps.set-ref.outputs.ref }}
submodules: true
- name: Show current commit
run: git log -1 --oneline
- name: Set up Python
run: |
which python
echo "Setting up Python with Conda"
conda create -n test_env python=3.11 -y
- name: Install dependencies
run: |
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
conda install -c conda-forge gxx=11 gcc=11
echo "Installing dependencies"
pip install torch torchvision torchaudio
pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
build:
needs: set-up-build-env
runs-on: [ self-hosted, ampere ]
timeout-minutes: 30
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run build tests
run: |
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_INSTALL_MODE=ALL python setup.py develop
pip install -r tests/requirements.txt
test-flux-memory:
needs: build
runs-on: [ self-hosted, ampere ]
timeout-minutes: 30
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run FLUX memory test
run: |
which python
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_TEST_CACHE_ROOT=${{ secrets.NUNCHAKU_TEST_CACHE_ROOT_AMPERE }} HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -v tests/flux/test_flux_memory.py
test-flux-other:
needs: build
runs-on: [ self-hosted, ampere ]
timeout-minutes: 150
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run other FLUX tests
run: |
which python
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_TEST_CACHE_ROOT=${{ secrets.NUNCHAKU_TEST_CACHE_ROOT_AMPERE }} HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -v tests/flux --ignore=tests/flux/test_flux_memory.py
test-sana:
needs: build
runs-on: [ self-hosted, ampere ]
timeout-minutes: 60
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run SANA tests
run: |
which python
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_TEST_CACHE_ROOT=${{ secrets.NUNCHAKU_TEST_CACHE_ROOT_AMPERE }} HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -v tests/sana
clean-up:
if: always() && (github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true')
needs: [ set-up-build-env, test-flux-memory, test-flux-other, test-sana ]
runs-on: [ self-hosted, ampere ]
steps:
- name: Clean up
run: |
cd ..
rm -rf *nunchaku*
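The `pr_test_ampere` workflow above (and the `pr_test_blackwell` workflow that follows) is manual-only for now: only `workflow_dispatch` is active, and the push, pull_request, and issue_comment triggers are commented out. As a rough sketch of how a maintainer might dispatch it with the GitHub CLI (assuming the file lives at `.github/workflows/pr_test_ampere.yml`; the PR number below is a placeholder):

```
# Test a specific pull request (hypothetical PR number)
gh workflow run pr_test_ampere.yml -f test_target=pr -f pr_number=123

# Or test a branch instead (branch_name defaults to main)
gh workflow run pr_test_ampere.yml -f test_target=branch -f branch_name=main
```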
name: pr_test_blackwell
on:
workflow_dispatch:
inputs:
test_target:
description: 'What to test: "pr" or "branch"'
required: true
type: choice
options:
- pr
- branch
pr_number:
description: 'Pull Request Number (only if test_target == "pr")'
required: false
branch_name:
description: 'Branch name (only if test_target == "branch")'
default: 'main'
required: false
# push:
# branches: [ main ]
# paths:
# - "nunchaku/**"
# - "src/**"
# - "tests/**"
# - "examples/**"
# pull_request:
# types: [ opened, synchronize, reopened, edited ]
# paths:
# - "nunchaku/**"
# - "src/**"
# - "tests/**"
# - "examples/**"
# issue_comment:
# types: [ created ]
concurrency:
group: ${{ github.repository }}-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
check-comment:
if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'issue_comment' && github.event.issue.pull_request && !github.event.pull_request.draft) }}
runs-on: [ self-hosted, blackwell ]
outputs:
should_run: ${{ steps.check.outputs.should_run }}
steps:
- id: check
run: |
body="${{ github.event.comment.body }}"
body_lower=$(echo "$body" | tr '[:upper:]' '[:lower:]')
if [[ "$body_lower" == "run tests" || "$body_lower" == "run test" ]]; then
echo "should_run=true" >> $GITHUB_OUTPUT
else
echo "should_run=false" >> $GITHUB_OUTPUT
fi
set-up-build-env:
runs-on: [ self-hosted, blackwell ]
needs: [ check-comment ]
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Determine ref
id: set-ref
run: |
if [[ "${{ github.event.inputs.test_target }}" == "pr" ]]; then
echo "ref=refs/pull/${{ github.event.inputs.pr_number }}/merge" >> $GITHUB_OUTPUT
else
echo "ref=refs/heads/${{ github.event.inputs.branch_name }}" >> $GITHUB_OUTPUT
fi
- name: Checkout
uses: actions/checkout@v4
with:
# ref: ${{ github.event.pull_request.head.sha || github.sha }}
ref: ${{ steps.set-ref.outputs.ref }}
submodules: true
- name: Show current commit
run: git log -1 --oneline
- name: Set up Python
run: |
which python
echo "Setting up Python with Conda"
conda create -n test_env python=3.11 -y
- name: Install dependencies
run: |
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
conda install -c conda-forge gxx=11 gcc=11
echo "Installing dependencies"
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
build:
needs: set-up-build-env
runs-on: [ self-hosted, blackwell ]
timeout-minutes: 30
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run build tests
run: |
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_INSTALL_MODE=ALL python setup.py develop
pip install -r tests/requirements.txt
test-flux-memory:
needs: build
runs-on: [ self-hosted, blackwell ]
timeout-minutes: 30
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run FLUX memory test
run: |
which python
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_TEST_CACHE_ROOT=${{ secrets.NUNCHAKU_TEST_CACHE_ROOT_BLACKWELL }} HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -v tests/flux/test_flux_memory.py
test-flux-other:
needs: build
runs-on: [ self-hosted, blackwell ]
timeout-minutes: 150
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run other FLUX tests
run: |
which python
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_TEST_CACHE_ROOT=${{ secrets.NUNCHAKU_TEST_CACHE_ROOT_BLACKWELL }} HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -v tests/flux --ignore=tests/flux/test_flux_memory.py
test-sana:
needs: build
runs-on: [ self-hosted, blackwell ]
timeout-minutes: 60
if: ${{ github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true' }}
steps:
- name: Run SANA tests
run: |
which python
source $(conda info --base)/etc/profile.d/conda.sh
conda activate test_env || { echo "Failed to activate conda env"; exit 1; }
which python
NUNCHAKU_TEST_CACHE_ROOT=${{ secrets.NUNCHAKU_TEST_CACHE_ROOT_BLACKWELL }} HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -v tests/sana
clean-up:
if: always() && (github.event_name != 'issue_comment' || needs.check-comment.outputs.should_run == 'true')
needs: [ set-up-build-env, test-flux-memory, test-flux-other, test-sana ]
runs-on: [ self-hosted, blackwell ]
steps:
- name: Clean up
run: |
cd ..
rm -rf *nunchaku*
@@ -129,7 +129,7 @@ If you're using a Blackwell GPU (e.g., 50-series GPUs), install a wheel with PyT
pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
# For gradio demos
pip install peft opencv-python gradio spaces GPUtil
```
To enable NVFP4 on Blackwell GPUs (e.g., 50-series GPUs), please install nightly PyTorch with CUDA 12.8. The installation command can be:
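A representative command, mirroring what the Blackwell CI jobs in this pull request install (the README's own pinned command may differ):

```
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
```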
@@ -342,3 +342,7 @@ We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Prog
We use [img2img-turbo](https://github.com/GaParmar/img2img-turbo) to train the sketch-to-image LoRA. Our text-to-image and image-to-image UI is built upon [playground-v.25](https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py) and [img2img-turbo](https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py), respectively. Our safety checker is borrowed from [hart](https://github.com/mit-han-lab/hart).
Nunchaku is also inspired by many open-source libraries, including (but not limited to) [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm), [QServe](https://github.com/mit-han-lab/qserve), [AWQ](https://github.com/mit-han-lab/llm-awq), [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), and [Atom](https://github.com/efeslab/Atom).
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=mit-han-lab/nunchaku&type=Date)](https://www.star-history.com/#mit-han-lab/nunchaku&Date)
@@ -27,26 +27,27 @@ RUN 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& apt update -y \
&& apt install software-properties-common -y \
&& add-apt-repository ppa:deadsnakes/ppa -y \
&& apt update
RUN apt install python${PYTHON_VERSION} python${PYTHON_VERSION}-dev g++-11 gcc-11 -y \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} && apt install python${PYTHON_VERSION}-distutils -y \
&& update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python /usr/bin/python${PYTHON_VERSION}
RUN apt install curl git sudo libibverbs-dev -y \
&& apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py \
&& python3 --version \
&& python3 -m pip --version \
&& rm -rf /var/lib/apt/lists/* \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 1 && update-alternatives --set gcc /usr/bin/gcc-11 \
&& update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 1 && update-alternatives --set g++ /usr/bin/g++-11 \
&& apt clean
&& apt update && apt install wget git -y && apt clean
# Install Miniconda
ENV CONDA_DIR=/opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \
&& bash /tmp/miniconda.sh -b -p ${CONDA_DIR} \
&& rm /tmp/miniconda.sh \
&& ${CONDA_DIR}/bin/conda clean -afy
ENV PATH=${CONDA_DIR}/bin:$PATH
RUN conda init bash
RUN conda create -y -n nunchaku python=${PYTHON_VERSION} \
&& conda install -y -n nunchaku -c conda-forge gxx=11 gcc=11 \
&& conda clean -afy
SHELL ["conda", "run", "-n", "nunchaku", "/bin/bash", "-c"]
# Install building dependencies
RUN pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} torchaudio==${TORCHAUDIO_VERSION} --index-url https://download.pytorch.org/whl/cu${CUDA_SHORT_VERSION}
RUN pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} torchaudio==${TORCHAUDIO_VERSION} --index-url https://download.pytorch.org/whl/cu124
RUN pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub comfy-cli
# Start building
@@ -57,6 +58,7 @@ RUN git clone https://github.com/mit-han-lab/nunchaku.git \
&& NUNCHAKU_INSTALL_MODE=ALL python setup.py develop
RUN cd .. && git clone https://github.com/comfyanonymous/ComfyUI \
&& cd ComfyUI/custom_nodes && git clone https://github.com/ltdrdata/ComfyUI-Manager comfyui-manager \
&& git clone https://github.com/mit-han-lab/ComfyUI-nunchaku.git nunchaku_nodes \
&& cd .. && mkdir -p user/default/workflows/ && cp custom_nodes/nunchaku_nodes/workflows/* user/default/workflows/
&& cd ComfyUI && pip install -r requirements.txt \
&& cd custom_nodes && git clone https://github.com/ltdrdata/ComfyUI-Manager comfyui-manager \
&& git clone https://github.com/mit-han-lab/ComfyUI-nunchaku.git \
&& cd .. && mkdir -p user/default/workflows/ && cp -r custom_nodes/ComfyUI-nunchaku/workflows/ user/default/workflows/nunchaku_examples
# Use an NVIDIA base image with CUDA support
ARG CUDA_IMAGE="12.8.1-devel-ubuntu24.04"
FROM nvidia/cuda:${CUDA_IMAGE}
ENV DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.11
ARG TORCH_VERSION=2.7
ARG TORCHVISION_VERSION=0.21
ARG TORCHAUDIO_VERSION=2.6
ARG CUDA_SHORT_VERSION=12.8
# Set working directory
WORKDIR /
RUN echo PYTHON_VERSION=${PYTHON_VERSION} \
&& echo CUDA_SHORT_VERSION=${CUDA_SHORT_VERSION} \
&& echo TORCH_VERSION=${TORCH_VERSION} \
&& echo TORCHVISION_VERSION=${TORCHVISION_VERSION} \
&& echo TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION}
# Setup timezone and install system dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select New_York' | debconf-set-selections \
&& apt update -y \
&& apt install software-properties-common -y \
&& add-apt-repository ppa:deadsnakes/ppa -y \
&& apt update && apt install wget git -y && apt clean
# Install Miniconda
ENV CONDA_DIR=/opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \
&& bash /tmp/miniconda.sh -b -p ${CONDA_DIR} \
&& rm /tmp/miniconda.sh \
&& ${CONDA_DIR}/bin/conda clean -afy
ENV PATH=${CONDA_DIR}/bin:$PATH
RUN conda init bash
RUN conda create -y -n nunchaku python=${PYTHON_VERSION} \
&& conda install -y -n nunchaku -c conda-forge gxx=11 gcc=11 \
&& conda clean -afy
SHELL ["conda", "run", "-n", "nunchaku", "/bin/bash", "-c"]
# Install building dependencies
RUN pip install --pre torch==2.7.0.dev20250307+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
RUN pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub comfy-cli
# Start building
RUN git clone https://github.com/mit-han-lab/nunchaku.git \
&& cd nunchaku \
&& git submodule init \
&& git submodule update \
&& NUNCHAKU_INSTALL_MODE=ALL python setup.py develop
RUN cd .. && git clone https://github.com/comfyanonymous/ComfyUI \
&& cd ComfyUI && pip install -r requirements.txt \
&& cd custom_nodes && git clone https://github.com/ltdrdata/ComfyUI-Manager comfyui-manager \
&& git clone https://github.com/mit-han-lab/ComfyUI-nunchaku.git \
&& cd .. && mkdir -p user/default/workflows/ && cp -r custom_nodes/ComfyUI-nunchaku/workflows/ user/default/workflows/nunchaku_examples
# Use an NVIDIA base image with CUDA support
ARG CUDA_IMAGE="12.8.1-devel-ubuntu24.04"
FROM nvidia/cuda:${CUDA_IMAGE}
ENV DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.11
ARG TORCH_VERSION=2.8
ARG TORCHVISION_VERSION=0.21
ARG TORCHAUDIO_VERSION=2.6
ARG CUDA_SHORT_VERSION=12.8
# Set working directory
WORKDIR /
RUN echo PYTHON_VERSION=${PYTHON_VERSION} \
&& echo CUDA_SHORT_VERSION=${CUDA_SHORT_VERSION} \
&& echo TORCH_VERSION=${TORCH_VERSION} \
&& echo TORCHVISION_VERSION=${TORCHVISION_VERSION} \
&& echo TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION}
# Setup timezone and install system dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select New_York' | debconf-set-selections \
&& apt update -y \
&& apt install software-properties-common -y \
&& add-apt-repository ppa:deadsnakes/ppa -y \
&& apt update && apt install wget git -y && apt clean
# Install Miniconda
ENV CONDA_DIR=/opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \
&& bash /tmp/miniconda.sh -b -p ${CONDA_DIR} \
&& rm /tmp/miniconda.sh \
&& ${CONDA_DIR}/bin/conda clean -afy
ENV PATH=${CONDA_DIR}/bin:$PATH
RUN conda init bash
RUN conda create -y -n nunchaku python=${PYTHON_VERSION} \
&& conda install -y -n nunchaku -c conda-forge gxx=11 gcc=11 \
&& conda clean -afy
SHELL ["conda", "run", "-n", "nunchaku", "/bin/bash", "-c"]
# Install building dependencies
RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
RUN pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub comfy-cli
# Start building
RUN git clone https://github.com/mit-han-lab/nunchaku.git \
&& cd nunchaku \
&& git submodule init \
&& git submodule update \
&& NUNCHAKU_INSTALL_MODE=ALL python setup.py develop
RUN cd .. && git clone https://github.com/comfyanonymous/ComfyUI \
&& cd ComfyUI && pip install -r requirements.txt \
&& cd custom_nodes && git clone https://github.com/ltdrdata/ComfyUI-Manager comfyui-manager \
&& git clone https://github.com/mit-han-lab/ComfyUI-nunchaku.git \
&& cd .. && mkdir -p user/default/workflows/ && cp -r custom_nodes/ComfyUI-nunchaku/workflows/ user/default/workflows/nunchaku_examples
@@ -5,7 +5,7 @@ from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
from nunchaku.utils import get_precision
from nunchaku.utils import get_gpu_memory, get_precision
base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model_union = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"
@@ -14,14 +14,21 @@ controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, t
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
precision = get_precision()
need_offload = get_gpu_memory() < 36
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-dev", torch_dtype=torch.bfloat16
f"mit-han-lab/svdq-{precision}-flux.1-dev", torch_dtype=torch.bfloat16, offload=need_offload
)
transformer.set_attention_impl("nunchaku-fp16")
pipeline = FluxControlNetPipeline.from_pretrained(
base_model, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")
)
if need_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
# apply_cache_on_pipe(
# pipeline, residual_diff_threshold=0.1
# ) # Uncomment this line to enable first-block cache to speedup generation
__version__ = "0.2.0"
__version__ = "0.3.0dev0"
@@ -9,6 +9,7 @@ from ...caching import utils
def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
if getattr(transformer, "_is_cached", False):
transformer.cached_transformer_blocks[0].update_threshold(residual_diff_threshold)
return transformer
cached_transformer_blocks = nn.ModuleList(
@@ -245,6 +245,9 @@ class FluxCachedTransformerBlocks(nn.Module):
self.return_hidden_states_only = return_hidden_states_only
self.verbose = verbose
def update_residual_diff_threshold(self, residual_diff_threshold=0.12):
self.residual_diff_threshold = residual_diff_threshold
def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 1:
@@ -69,7 +69,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_single_block_samples=None,
skip_first_layer=False,
):
batch_size = hidden_states.shape[0]
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
@@ -95,9 +95,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
# [bs, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
# [1, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb # .to(self.dtype)
@@ -135,7 +135,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_block_samples=None,
controlnet_single_block_samples=None,
):
batch_size = hidden_states.shape[0]
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
@@ -155,9 +155,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
# [bs, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
# [1, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
@@ -2,17 +2,17 @@ import torch
from diffusers import FluxPipeline
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
if __name__ == "__main__":
capability = torch.cuda.get_device_capability(0)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
precision = get_precision()
torch_dtype = torch.float16 if is_turing() else torch.bfloat16
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=True
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, offload=True
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
)
pipeline.enable_sequential_cpu_offload()
image = pipeline(
@@ -105,3 +105,26 @@ def is_turing(device: str | torch.device = "cuda") -> bool:
capability = torch.cuda.get_device_capability(device_id)
sm = f"{capability[0]}{capability[1]}"
return sm == "75"
def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> int:
"""Get the GPU memory of the current device.
Args:
device (`str` | `torch.device`, optional, defaults to `"cuda"`):
device.
Returns:
`int`:
GPU memory in bytes.
"""
if isinstance(device, str):
device = torch.device(device)
assert unit in ("GiB", "MiB", "B")
memory = torch.cuda.get_device_properties(device).total_memory
if unit == "GiB":
return memory // (1024**3)
elif unit == "MiB":
return memory // (1024**2)
else:
return memory
#!/bin/bash
# Define the versions for Python, Torch, and CUDA
NUNCHAKU_VERSION=$1
python_versions=("3.10" "3.11" "3.12")
torch_versions=("2.5" "2.6")
cuda_versions=("12.4" "12.8")
# Loop through all combinations of Python, Torch, and CUDA versions
for python_version in "${python_versions[@]}"; do
for torch_version in "${torch_versions[@]}"; do
# Skip building for Python 3.13 and PyTorch 2.5
if [[ "$python_version" == "3.13" && "$torch_version" == "2.5" ]]; then
echo "Skipping Python 3.13 with PyTorch 2.5"
continue
fi
for cuda_version in "${cuda_versions[@]}"; do
bash scripts/build_docker.sh "$python_version" "$torch_version" "$cuda_version" "$NUNCHAKU_VERSION"
done
done
done
for python_version in "${python_versions[@]}"; do
for cuda_version in "${cuda_versions[@]}"; do
bash scripts/build_docker_torch27.sh "$python_version" "2.7" "$cuda_version" "$NUNCHAKU_VERSION"
bash scripts/build_docker_torch28.sh "$python_version" "2.8" "$cuda_version" "$NUNCHAKU_VERSION"
done
done
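As a usage sketch (the script path and version tag are illustrative, not taken from this diff), the matrix build above is driven by a single argument, the Nunchaku version used to tag the images:

```
# Builds and pushes one lmxyy/nunchaku image per (Python, Torch, CUDA) combination
bash scripts/build_docker_all.sh 0.3.0dev0
```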
@@ -29,12 +29,14 @@ else
exit 2
fi
docker build --no-cache \
docker build -f docker/Dockerfile --no-cache \
--build-arg PYTHON_VERSION=${PYTHON_VERSION} \
--build-arg CUDA_SHORT_VERSION=${CUDA_VERSION//.} \
--build-arg CUDA_IMAGE=${CUDA_IMAGE} \
--build-arg TORCH_VERSION=${TORCH_VERSION} \
--build-arg TORCHVISION_VERSION=${TORCHVISION_VERSION} \
--build-arg TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION} \
-t nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
-t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
#!/bin/bash
PYTHON_VERSION=$1
TORCH_VERSION=$2
CUDA_VERSION=$3
NUNCHAKU_VERSION=$4
TORCHVISION_VERSION=""
TORCHAUDIO_VERSION=""
if [ "$CUDA_VERSION" == "12.8" ]; then
CUDA_IMAGE="12.8.1-devel-ubuntu24.04"
echo "CUDA_VERSION is 12.8, setting CUDA_IMAGE to $CUDA_IMAGE"
elif [ "$CUDA_VERSION" == "12.4" ]; then
CUDA_IMAGE="12.4.1-devel-ubuntu22.04"
echo "CUDA_VERSION is 12.4, setting CUDA_IMAGE to $CUDA_IMAGE"
else
echo "CUDA_VERSION is not 12.8 or 12.4. Exit."
exit 2
fi
docker build -f docker/Dockerfile.torch27 --no-cache \
--build-arg PYTHON_VERSION=${PYTHON_VERSION} \
--build-arg CUDA_SHORT_VERSION=${CUDA_VERSION//.} \
--build-arg CUDA_IMAGE=${CUDA_IMAGE} \
--build-arg TORCH_VERSION=${TORCH_VERSION} \
--build-arg TORCHVISION_VERSION=${TORCHVISION_VERSION} \
--build-arg TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION} \
-t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
#!/bin/bash
PYTHON_VERSION=$1
TORCH_VERSION=$2
CUDA_VERSION=$3
NUNCHAKU_VERSION=$4
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if [ "$TORCH_VERSION" == "2.5" ]; then
TORCHVISION_VERSION="0.20"
TORCHAUDIO_VERSION="2.5"
echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
elif [ "$TORCH_VERSION" == "2.6" ]; then
TORCHVISION_VERSION="0.21"
TORCHAUDIO_VERSION="2.6"
echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
else
echo "TORCH_VERSION is not 2.5 or 2.6. Exit."
exit 2
fi
if [ "$CUDA_VERSION" == "12.8" ]; then
CUDA_IMAGE="12.8.1-devel-ubuntu24.04"
echo "CUDA_VERSION is 12.8, setting CUDA_IMAGE to $CUDA_IMAGE"
elif [ "$CUDA_VERSION" == "12.4" ]; then
CUDA_IMAGE="12.4.1-devel-ubuntu22.04"
echo "CUDA_VERSION is 12.4, setting CUDA_IMAGE to $CUDA_IMAGE"
else
echo "CUDA_VERSION is not 12.8 or 12.4. Exit."
exit 2
fi
docker build -f docker/Dockerfile.torch28 --no-cache \
--build-arg PYTHON_VERSION=${PYTHON_VERSION} \
--build-arg CUDA_SHORT_VERSION=${CUDA_VERSION//.} \
--build-arg CUDA_IMAGE=${CUDA_IMAGE} \
--build-arg TORCH_VERSION=${TORCH_VERSION} \
--build-arg TORCHVISION_VERSION=${TORCHVISION_VERSION} \
--build-arg TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION} \
-t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
@@ -60,7 +60,8 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
return Output{norm_x, gate_msa};
}
@@ -89,7 +90,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
debug("norm_x_scaled", norm_x);
return Output{norm_x};
@@ -100,7 +102,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
debug("norm_x_scaled", norm_x);
return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
@@ -335,7 +338,9 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
for (int i = 0; i < batch_size; i++) {
qkv_proj.forward(norm_hidden_states.slice(0, i, i+1), qkv.slice(0, i, i+1), {}, norm_q.weight, norm_k.weight, rotary_emb);
}
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
@@ -343,7 +348,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
attn_output = attn.forward(qkv);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
assert(batch_size == 1);
// assert(batch_size == 1);
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
@@ -351,7 +356,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
for (int i = 0; i < batch_size; i++) {
qkv_proj.forward(
norm_hidden_states.slice(0, i, i+1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
q.slice(0, i, i+1),
k.slice(0, i, i+1),
v.slice(0, i, i+1),
num_tokens);
}
debug("packed_q", q);
debug("packed_k", k);
@@ -361,7 +373,21 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
attn_output = o.slice(1, 0, num_tokens);
if (batch_size == 1 || num_tokens_pad == num_tokens) {
attn_output = o.slice(1, 0, num_tokens);
} else {
attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
checkCUDA(cudaMemcpy2DAsync(
attn_output.data_ptr(),
attn_output.stride(0) * attn_output.scalar_size(),
o.data_ptr(),
o.stride(0) * o.scalar_size(),
attn_output.stride(0) * attn_output.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
getCurrentCUDAStream()
));
}
} else {
assert(false);
}
@@ -379,7 +405,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
hidden_states = kernels::add(attn_output, ff_output);
debug("attn_ff_output", hidden_states);
kernels::mul_add(hidden_states, gate, residual);
// kernels::mul_add(hidden_states, gate, residual);
kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);
nvtxRangePop();
@@ -627,7 +654,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("img.attn_output", attn_output);
#if 1
kernels::mul_add(attn_output, gate_msa, hidden_states);
// kernels::mul_add(attn_output, gate_msa, hidden_states);
kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
hidden_states = std::move(attn_output);
nvtxRangePop();
@@ -638,7 +666,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor norm_hidden_states = norm2.forward(hidden_states);
debug("scale_mlp", scale_mlp);
debug("shift_mlp", shift_mlp);
kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
@@ -651,7 +680,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("img.ff_output", ff_output);
debug("gate_mlp", gate_mlp);
kernels::mul_add(ff_output, gate_mlp, hidden_states);
// kernels::mul_add(ff_output, gate_mlp, hidden_states);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
hidden_states = std::move(ff_output);
nvtxRangePop();
@@ -692,7 +722,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("context.attn_output", attn_output);
#if 1
kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
// kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
encoder_hidden_states = std::move(attn_output);
nvtxRangePop();
@@ -703,7 +734,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
debug("c_scale_mlp", scale_mlp);
debug("c_shift_mlp", shift_mlp);
kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
@@ -718,7 +750,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("context.ff_output", ff_output);
debug("c_gate_mlp", gate_mlp);
kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
// kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
encoder_hidden_states = std::move(ff_output);
nvtxRangePop();
@@ -741,7 +774,7 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
if (offload) {
single_transformer_blocks.back()->setLazyLoad(true);
@@ -791,8 +824,8 @@ Tensor FluxModel::forward(
// txt first, same as diffusers
concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states.slice(0, i, i + 1));
}
hidden_states = concat;
encoder_hidden_states = {};
@@ -73,8 +73,9 @@ Tensor GEMV_AWQ::forward(Tensor x) {
Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
if (bias.valid()) {
// TODO: batch
assert(out.numel() == bias.numel());
out = kernels::add(out, bias.view(out.shape.dataExtent));
// assert(out.numel() == bias.numel());
// out = kernels::add(out, bias.view(out.shape.dataExtent));
kernels::mul_add_batch(out, {}, false, 0.0, bias, false);
}
debug("out_before_lora", out);