Commit 7e1d5e53 authored by zhuwenwen

merge v0.3.1

parents e3378b20 5f08050d
...@@ -6,15 +6,16 @@ set -o pipefail
# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
# run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?
# run server-based benchmarks and upload the result to buildkite
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
server_pid=$!
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
...@@ -22,11 +23,14 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
# wait for server to start, timeout after 600 seconds
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend openai \
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
--model meta-llama/Llama-2-7b-chat-hf \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer meta-llama/Llama-2-7b-chat-hf \
--save-result \
2>&1 | tee benchmark_serving.txt
bench_serving_exit_code=$?
kill $server_pid
...@@ -44,7 +48,7 @@ sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines
# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
...@@ -61,3 +65,5 @@ fi
if [ $bench_serving_exit_code -ne 0 ]; then
exit $bench_serving_exit_code
fi
/workspace/buildkite-agent artifact upload openai-*.json
...@@ -49,3 +49,10 @@ steps:
commands:
- pip install aiohttp
- bash run-benchmarks.sh
- label: Documentation Build
working_dir: "/vllm-workspace/docs"
no_gpu: True
commands:
- pip install -r requirements-docs.txt
- SPHINXOPTS=\"-W\" make html
...@@ -35,13 +35,15 @@ steps:
- image: "{{ docker_image }}"
command: ["bash"]
args:
- '-c'
- "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
{% if not step.no_gpu %}
resources:
requests:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
limits:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
{% endif %}
env:
- name: HF_TOKEN
valueFrom:
...
...@@ -7,6 +7,12 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
&& apt-get install -y python3-pip git
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/
WORKDIR /workspace
# install build and runtime dependencies
...@@ -69,8 +75,10 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal
#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
# libnccl required for ray
RUN apt-get update -y \
...
...@@ -10,9 +10,6 @@ RUN echo "Base image is $BASE_IMAGE"
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
# this does not always work for all rocm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
...@@ -20,6 +17,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"
# whether to build flash-attention
# if 0, will not build flash attention
# this is useful for gfx target where flash-attention is not supported
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
...@@ -53,9 +56,10 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
# Install ROCm flash-attention
RUN if [ "$BUILD_FA" = "1" ]; then \
mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
...@@ -63,7 +67,8 @@ RUN mkdir libs \
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
&& cd ..; \
fi
COPY ./ /app/vllm
...@@ -78,7 +83,9 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& if [ "$BUILD_FA" = "1" ]; then \
bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cd ..
...
# <div align="center"><strong>vLLM</strong></div>
## Introduction
vLLM is a fast and easy-to-use library for LLM inference and serving. It uses PagedAttention to manage KV cache memory efficiently, applies continuous batching to incoming requests, and supports many Hugging Face models, such as LLaMA & LLaMA-2, Qwen, and Chatglm2 & Chatglm3.
## Installation
vLLM supports
...@@ -41,7 +41,7 @@ python3 setup.py install
+ If pip install is slow to download packages, you can add a mirror source: -i https://pypi.tuna.tsinghua.edu.cn/simple/
## Verification
- python -c "import vllm; print(vllm.\_\_version__)"; the version number is kept in sync with the official release and reports the installed version, e.g. 0.3.1
## Known Issue
-
...
...@@ -16,16 +16,8 @@ Easy, fast, and cheap LLM serving for everyone
---
**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**
We are thrilled to announce our second vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
Please register [here](https://lu.ma/ygxbpzhl) and join us!
---
*Latest News* 🔥
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
...@@ -73,6 +65,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
...
import json
import os
import time
from dataclasses import dataclass
from typing import Optional
import aiohttp
from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
prompt: str
api_url: str
prompt_len: int
output_len: int
model: str
best_of: int = 1
use_beam_search: bool = False
@dataclass
class RequestFuncOutput:
generated_text: str = ""
success: bool = False
latency: float = 0
ttft: float = 0
prompt_len: int = 0
async def async_request_tgi(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
params = {
"best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len,
"do_sample": True,
"temperature": 0.01, # TGI does not accept 0.0 temperature.
"top_p": 0.99, # TGI does not accept 1.0 top_p.
}
payload = {
"inputs": request_func_input.prompt,
"parameters": params,
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
st = time.perf_counter()
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for data in response.content.iter_any():
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
output.latency = time.perf_counter() - st
body = data.decode("utf-8").lstrip("data:")
output.generated_text = json.loads(body)["generated_text"]
output.success = True
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False
if pbar:
pbar.update(1)
return output
async def async_request_vllm(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("generate")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"prompt": request_func_input.prompt,
"n": 1,
"best_of": request_func_input.best_of,
"use_beam_search": request_func_input.use_beam_search,
"temperature": 0.0 if request_func_input.use_beam_search else 1.0,
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"ignore_eos": True,
"stream": True,
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
st = time.perf_counter()
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for data in response.content.iter_any():
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
output.latency = time.perf_counter() - st
# When streaming, '\0' is appended to the end of the response.
body = data.decode("utf-8").strip("\0")
output.generated_text = json.loads(
body)["text"][0][len(request_func_input.prompt):]
output.success = True
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False
if pbar:
pbar.update(1)
return output
async def async_request_trt_llm(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
assert request_func_input.best_of == 1
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
"temperature": 0.0,
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"stream": True,
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
st = time.perf_counter()
try:
async with session.post(url=api_url, json=payload) as resp:
if resp.status == 200:
async for data in resp.content.iter_any():
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
output.latency = time.perf_counter() - st
body = data.decode("utf-8").lstrip("data:")
output.generated_text = json.loads(body)["text_output"]
output.success = True
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False
if pbar:
pbar.update(1)
return output
async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
assert not request_func_input.use_beam_search
payload = {
"prompts": request_func_input.prompt,
"max_new_tokens": request_func_input.output_len,
"ignore_eos": True,
"do_sample": True,
"temperature":
0.01, # deepspeed-mii does not accept 0.0 temperature.
"top_p": 1.0,
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
# https://github.com/microsoft/DeepSpeed-MII/pull/311
output.ttft = 0
st = time.perf_counter()
try:
async with session.post(url=request_func_input.api_url,
json=payload) as resp:
if resp.status == 200:
parsed_resp = await resp.json()
output.latency = time.perf_counter() - st
output.generated_text = parsed_resp[0]["generated_text"]
output.success = True
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False
if pbar:
pbar.update(1)
return output
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("v1/completions")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"best_of": request_func_input.best_of,
"max_tokens": request_func_input.output_len,
"stream": True,
}
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0
st = time.perf_counter()
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
chunk = chunk.strip()
if not chunk:
continue
chunk = chunk.decode("utf-8").lstrip("data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
body = json.loads(chunk)
generated_text += body["choices"][0]["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False
if pbar:
pbar.update(1)
return output
ASYNC_REQUEST_FUNCS = {
"tgi": async_request_tgi,
"vllm": async_request_vllm,
"deepspeed-mii": async_request_deepspeed_mii,
"openai": async_request_openai_completions,
"tensorrt-llm": async_request_trt_llm,
}
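For quick reference (not part of the diff), a minimal driver for the request functions above could look like the sketch below. It assumes an OpenAI-compatible vLLM server is already running on localhost:8000 serving meta-llama/Llama-2-7b-chat-hf (as started in run-benchmarks.sh), that the file above is importable as backend_request_func, and that the prompt and token counts are illustrative only.

```python
import asyncio

from backend_request_func import RequestFuncInput, async_request_openai_completions


async def smoke_test():
    # Illustrative values; prompt_len and output_len are normally derived from
    # the tokenized ShareGPT dataset in benchmark_serving.py.
    request_input = RequestFuncInput(
        prompt="Hello, my name is",
        api_url="http://localhost:8000/v1/completions",  # must end with v1/completions
        prompt_len=5,
        output_len=16,
        model="meta-llama/Llama-2-7b-chat-hf",
    )
    output = await async_request_openai_completions(request_input)
    print("success:", output.success)
    print("ttft (s):", output.ttft, "latency (s):", output.latency)
    print("generated:", output.generated_text)


if __name__ == "__main__":
    asyncio.run(smoke_test())
```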
...@@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
)
sampling_params = SamplingParams(
...@@ -36,7 +37,10 @@ def main(args: argparse.Namespace):
max_tokens=args.output_len,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
...@@ -70,7 +74,7 @@ def main(args: argparse.Namespace):
"."
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
# Benchmark.
...@@ -135,5 +139,11 @@ if __name__ == '__main__':
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
args = parser.parse_args()
main(args)
...@@ -20,16 +20,36 @@ import asyncio
import json
import random
import time
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
RequestFuncInput,
RequestFuncOutput,
)
@dataclass
class BenchmarkMetrics:
completed: int
total_input: int
total_output: int
request_throughput: float
input_throughput: float
output_throughput: float
mean_ttft_ms: float
median_ttft_ms: float
p99_ttft_ms: float
mean_tpot_ms: float
median_tpot_ms: float
p99_tpot_ms: float
def sample_requests(
...@@ -46,6 +66,11 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# some of these will be filtered out, so sample more than we need
sampled_indices = random.sample(range(len(dataset)),
int(num_requests * 1.2))
dataset = [dataset[i] for i in sampled_indices]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
...@@ -92,80 +117,125 @@ async def get_request(
await asyncio.sleep(interval)
def calculate_metrics(
input_requests: List[Tuple[str, int, int]],
outputs: List[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
) -> BenchmarkMetrics:
total_output = 0
total_input = 0
completed = 0
per_token_latencies = []
ttfts = []
for i in range(len(outputs)):
if outputs[i].success:
output_len = len(tokenizer.encode(outputs[i].generated_text))
total_output += output_len
total_input += input_requests[i][1]
per_token_latencies.append(outputs[i].latency / output_len)
ttfts.append(outputs[i].ttft)
completed += 1
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=total_output,
request_throughput=completed / dur_s,
input_throughput=total_input / dur_s,
output_throughput=total_output / dur_s,
mean_ttft_ms=np.mean(ttfts) * 1000,
median_ttft_ms=np.median(ttfts) * 1000,
p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
mean_tpot_ms=np.mean(per_token_latencies) * 1000,
median_tpot_ms=np.median(per_token_latencies) * 1000,
p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
)
return metrics
async def benchmark(
backend: str,
api_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
best_of: int,
use_beam_search: bool,
request_rate: float,
disable_tqdm: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS.get(backend)
else:
raise ValueError(f"Unknown backend: {backend}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
print(f"Traffic request rate: {request_rate}")
benchmark_start_time = time.perf_counter()
tasks = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
request_func_input = RequestFuncInput(
model=model_id,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
tasks.append(
asyncio.create_task(
request_func(request_func_input=request_func_input,
pbar=pbar)))
outputs = await asyncio.gather(*tasks)
if not disable_tqdm:
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
)
print(f"Successful requests: {metrics.completed}")
print(f"Benchmark duration: {benchmark_duration:.2f} s")
print(f"Total input tokens: {metrics.total_input}")
print(f"Total generated tokens: {metrics.total_output}")
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"mean_ttft_ms": metrics.mean_ttft_ms,
"median_ttft_ms": metrics.median_ttft_ms,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms
}
return result
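As a quick illustration of the units reported above (not part of the script): calculate_metrics treats TPOT as a request's end-to-end latency divided by the number of tokens counted in its generated text, and TTFT as the time to the first streamed chunk, both converted to milliseconds.

```python
# Toy numbers, for illustration only.
latency_s = 2.0        # end-to-end latency reported by the request function
ttft_s = 0.25          # time to first token for the same request
output_tokens = 100    # len(tokenizer.encode(generated_text))

tpot_ms = latency_s / output_tokens * 1000
ttft_ms = ttft_s * 1000
print(f"TPOT = {tpot_ms:.2f} ms, TTFT = {ttft_ms:.2f} ms")
# TPOT = 20.00 ms, TTFT = 250.00 ms
```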
def main(args: argparse.Namespace):
...@@ -173,77 +243,145 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
tokenizer = get_tokenizer(tokenizer_id,
trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
benchmark_result = asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
best_of=args.best_of,
use_beam_search=args.use_beam_search,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
))
# Save config and results to json
if args.save_result:
result_json = {}
# Setup
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
result_json["date"] = current_dt
result_json["backend"] = backend
result_json["version"] = args.version
result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of
result_json["use_beam_search"] = args.use_beam_search
result_json["num_prompts"] = args.num_prompts
# Traffic
result_json["request_rate"] = (
args.request_rate if args.request_rate < float("inf") else "inf")
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
# Save to file
base_model_id = model_id.split("/")[-1]
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
with open(file_name, "w") as outfile:
json.dump(result_json, outfile)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
default="vllm",
choices=list(ASYNC_REQUEST_FUNCS.keys()),
)
parser.add_argument(
"--version",
type=str,
default="N/A",
help="Version of the serving backend/engine.",
)
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server or API base url if not using http host and port.",
)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--endpoint",
type=str,
default="/generate",
help="API endpoint.",
)
parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.")
parser.add_argument(
"--model",
type=str,
required=True,
help="Name of the model.",
)
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default model tokenizer.",
)
parser.add_argument(
"--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Trust remote code from huggingface",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--save-result",
action="store_true",
help="Specify to save benchmark results to a json file",
)
args = parser.parse_args()
main(args)
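For reference (not part of the script), a small sketch of summarizing the JSON files written by --save-result; the glob pattern is illustrative and relies on the "{backend}-{request_rate}qps-{model}-{date}.json" naming used above, and only keys defined in the result dictionary are read.

```python
import glob
import json

# Illustrative: summarize whatever result files the benchmark wrote in the
# current directory (e.g. the openai-*.json files uploaded by run-benchmarks.sh).
for path in sorted(glob.glob("*qps-*.json")):
    with open(path) as f:
        result = json.load(f)
    req_per_s = result["completed"] / result["duration"]
    print(f"{path}: {req_per_s:.2f} req/s, "
          f"median TTFT {result['median_ttft_ms']:.2f} ms, "
          f"median TPOT {result['median_tpot_ms']:.2f} ms")
```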
...@@ -72,6 +72,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
...@@ -85,6 +86,7 @@ def run_vllm(
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
# Add the requests to the engine.
...@@ -209,7 +211,7 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
...@@ -294,6 +296,12 @@ if __name__ == "__main__":
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
...
...@@ -25,18 +25,20 @@ def main(
dtype: torch.dtype,
seed: int,
do_profile: bool,
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device=device)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
...@@ -44,11 +46,11 @@ def main(
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device=device)
context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
...@@ -59,12 +61,17 @@ def main(
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel.
...@@ -84,7 +91,7 @@ def main(
)
max_logits = torch.empty_like(exp_sums)
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
...@@ -135,6 +142,7 @@ def main(
# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=3, profile=False)
# Benchmark.
...@@ -175,6 +183,7 @@ if __name__ == '__main__':
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
args = parser.parse_args()
print(args)
...
...@@ -6,7 +6,7 @@ TOKENS=$2
docker run --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4.0 \
--model-id $MODEL \
--sharded false \
--max-input-length 1024 \
...
...@@ -25,7 +25,9 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif
#include <algorithm>
...
...@@ -4,13 +4,20 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
...
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}
#pragma once
#include <torch/extension.h>
void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
/*
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
namespace vllm {
namespace moe {
static constexpr int WARP_SIZE = 32;
/// Aligned array type
template <
typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
float data[N];
};
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
// Don't touch finished rows.
if ((finished != nullptr) && finished[blockIdx.x])
{
return;
}
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0)
{
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0)
{
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
}
}
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{
using cub_kvp = cub::KeyValuePair<int, float>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int num_rows = gridDim.x;
const int block_row = blockIdx.x;
const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
{
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
{
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert)
{
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0)
{
// Ignore experts the node isn't responsible for with expert parallelism
const int expert = result_kvp.key;
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
const int idx = k * block_row + k_idx;
output[idx] = result_kvp.value;
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
are a small power of 2. This allows us to cleanly share the rows among the threads in
a single warp and eliminate communication between warps (so no need to use shared mem).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is intended for when the number of experts is a small power of 2.
2) This implementation assumes k is small, but will work for any k.
*/
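For intuition (not part of the source), the fused kernel described above boils down to the following PyTorch reference: a softmax over the expert logits of each token followed by a top-k selection. The sketch ignores the expert-parallelism offsets (start_expert/end_expert), the finished-row mask, the token_expert_indices output, and the exact lowest-index tie-breaking rule.

```python
import torch


def topk_gating_softmax_reference(gating_output: torch.Tensor, topk: int):
    """Reference for the fused kernel: softmax over experts, then top-k.

    gating_output: [num_tokens, num_experts] float logits.
    Returns (topk_weights, topk_indices), each of shape [num_tokens, topk].
    """
    probs = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_indices = torch.topk(probs, k=topk, dim=-1)
    return topk_weights, topk_indices


# Example: 4 tokens routed over 8 experts, picking the top 2 experts per token.
logits = torch.randn(4, 8)
weights, indices = topk_gating_softmax_reference(logits, topk=2)
print(weights.shape, indices.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```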
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
// Thus, each block processes a chunk of rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows)
{
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belongs to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
{
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
float thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii)
{
thread_max = max(thread_max, row_chunk[ii]);
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float row_sum = 0;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
row_sum += row_chunk[ii];
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
// argmax after computing the softmax.
const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
{
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
{
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if (val > max_val)
{
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
{
max_val = other_max;
expert = other_expert;
}
}
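// Illustrative example: if the row maximum 0.25 appears in lanes 1 and 3 with expert indices 9 and 5
// respectively, the tie-break keeps expert 5 in every lane, so the whole group agrees on a single
// winner for this k_idx.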
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0)
{
// Add a guard to ignore experts not included by this node
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single thread per row of the input/output matrices.)
const int idx = k * thread_row + k_idx;
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
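// Layout note (illustrative): idx = k * thread_row + k_idx addresses the (thread_row, k_idx) entry of the
// row-major [num_rows, k] output/indices arrays, while the value stored in source_rows encodes the source
// position as k_idx * num_rows + thread_row, i.e. in a [k, num_rows] ordering.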
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
if (k_idx + 1 < k)
{
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group)
{
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
}
}
}
}
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
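// Worked example (illustrative, assuming WARP_SIZE == 32): EXPERTS = 8 with BYTES_PER_LDG = 16 gives
// ELTS_PER_LDG = 4, VECs_PER_THREAD = 1, VPT = 4, THREADS_PER_ROW = 2 and ROWS_PER_WARP = 16, so two
// threads cooperate on each row and one warp covers 16 rows. EXPERTS = 256 gives VPT = 8,
// THREADS_PER_ROW = 32 and ROWS_PER_WARP = 1, i.e. a full warp per row.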
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}
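// Launch-shape example (illustrative): with EXPERTS = 8, WARPS_PER_TB = 4 and ROWS_PER_WARP = 16, each
// block covers 64 rows, so num_rows = 1000 yields num_warps = 63, num_blocks = 16 and block_dim = (32, 4).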
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indicies, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
const int num_experts,
const int topk,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
break;
case 2:
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
break;
case 4:
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
break;
case 8:
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
break;
case 16:
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
break;
case 32:
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
break;
case 64:
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
break;
case 128:
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
break;
case 256:
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
break;
default: {
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided when num_experts is not a power of 2 in [1, 256].");
static constexpr int TPB = 256;
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
num_experts, topk, 0, num_experts);
}
}
}
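// Dispatch summary (illustrative): power-of-2 expert counts from 1 to 256 take the fused single-kernel
// path above; anything else runs moeSoftmax to materialize the full [num_tokens, num_experts] softmax in
// softmax_workspace and then moeTopK to select the top-k entries from it.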
} // namespace moe
} // namespace vllm
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output) // [num_tokens, num_experts]
{
const int num_experts = gating_output.size(-1);
const int num_tokens = gating_output.numel() / num_experts;
const int topk = topk_weights.size(-1);
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
const bool needs_workspace = !is_pow_2 || num_experts > 256;
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
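// Example (illustrative): num_experts = 64 hits the fused kernel, so workspace_size = 0; num_experts = 100
// is not a power of 2, so the two-kernel fallback needs num_tokens * 100 floats of workspace.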
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
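// Usage sketch (illustrative; the tensor set-up below is an assumption, not part of this file): given
// float32 CUDA gating logits of shape [num_tokens, num_experts], a caller can allocate the outputs and
// invoke the op roughly as follows:
//   auto topk_weights = torch::empty({num_tokens, topk}, gating_output.options());
//   auto topk_indices = torch::empty({num_tokens, topk}, gating_output.options().dtype(torch::kInt32));
//   auto token_expert_indices = torch::empty_like(topk_indices);
//   topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output);
// which fills the per-token softmaxed top-k weights and the corresponding expert indices.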
...@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding, &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
#ifndef USE_ROCM
// Quantization ops // Quantization ops
#ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif #endif
......
...@@ -27,72 +27,85 @@ __pack_half2(const half x, const half y) { ...@@ -27,72 +27,85 @@ __pack_half2(const half x, const half y) {
return (v1 << 16) | v0; return (v1 << 16) | v0;
} }
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) template<int N>
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int G,
int split_k_iters,
half* __restrict__ A,
int* __restrict__ B,
half* __restrict__ scaling_factors,
int* __restrict__ zeros,
int M,
int IC,
int OC,
half* __restrict__ C)
{ {
// Only N = 64 or 128 is supported.
assert(N == 64 || N == 128);
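// Templating note (illustrative): N is the per-block output-column tile; the former m16n128k32 and
// m16n64k32 kernels correspond to the N = 128 and N = 64 instantiations of this single template.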
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false); assert(false);
#else #else
static constexpr uint32_t ZERO = 0x0; static constexpr uint32_t ZERO = 0x0;
float C_warp[32]; float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)]; __shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (128 + 8)]; __shared__ half B_shared[32 * (N + 8)];
__shared__ half scaling_factors_shared[128];
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128); __shared__ half scaling_factors_shared[N];
__shared__ half zeros_shared[N];
int j_factors1 = ((OC + N - 1) / N);
int blockIdx_x = 0; int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8]; half A_shared_warp[8];
half B_shared_warp[32]; half B_shared_warp[N / 4];
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0; C_warp[(j_0_4_init * 8) + i] = 0.0;
} }
} }
static constexpr int row_stride_warp = 32 * 8 / 32; static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 128; static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M; // bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8; + (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 2 + ((int)threadIdx.y) * (OC / 8) * (256 / N)
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8) + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8) + (((int)blockIdx_y) % j_factors1) * (N / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1; + (((int)threadIdx.x) % (N / 8)) * 1;
// Why * 1 in the above line? // Why * 1 in the above line?
half* A_shared_ptr = A_shared half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8) + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8; + (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8) + (((int)threadIdx.x) / (N / 8)) * (N + 8)
+ (((int)threadIdx.x) % (128 / 8)) * 8; + (((int)threadIdx.x) % (N / 8)) * 8;
int* zeros_ptr = zeros int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (128 / 8) + (((int)blockIdx_y) % j_factors1) * (N / 8)
+ ((int)threadIdx.x) % (128 / 8); + ((int)threadIdx.x) % (N / 8);
half* scaling_factors_ptr = scaling_factors half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (128) + (((int)blockIdx_y) % j_factors1) * N
+ (((int)threadIdx.x) % (128 / 8)) * 8; + (((int)threadIdx.x) % (N / 8)) * 8;
half* C_ptr = C half* C_ptr = C
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128 + (((int)blockIdx_y) % j_factors1) * N
+ ((int)threadIdx.y) * 64 + ((int)threadIdx.y) * (N / 2)
+ (((int)threadIdx.x) % 4) * 2; + (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros // preload s.f. and zeros
...@@ -123,13 +136,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -123,13 +136,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16 // B: 32 x (N + 8) float16
// each warp: 32 x 4 // each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N) // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
...@@ -152,7 +165,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -152,7 +165,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
*/ */
// write back // write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
} }
__syncthreads(); __syncthreads();
...@@ -174,13 +187,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -174,13 +187,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
); );
} }
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr) : "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
); );
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
...@@ -190,7 +203,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -190,7 +203,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
); );
} }
} }
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
{ {
__asm__ __volatile__( __asm__ __volatile__(
...@@ -258,241 +271,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -258,241 +271,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
#endif #endif
} }
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#else
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#endif
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
#endif
}
__global__ void __launch_bounds__(64) dequantize_weights( __global__ void __launch_bounds__(64) dequantize_weights(
int* __restrict__ B, int* __restrict__ B,
half* __restrict__ scaling_factors, half* __restrict__ scaling_factors,
...@@ -526,26 +304,24 @@ __global__ void __launch_bounds__(64) dequantize_weights( ...@@ -526,26 +304,24 @@ __global__ void __launch_bounds__(64) dequantize_weights(
int index4 = 8 * col + (int)(row / G) * N * 8; int index4 = 8 * col + (int)(row / G) * N * 8;
half* scaling_factors_ptr2 = scaling_factors + index4; half* scaling_factors_ptr2 = scaling_factors + index4;
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); uint32_t B_loaded = *(uint32_t*)B_ptr2;
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
int j=0; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); *(uint4*)B_shared_ptr2 = B_loaded_fp16;
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
for (int i=0; i<8; ++i) { for (int i = 0; i < 8; ++i) {
*(C_ptr2 + i) = B_shared[i]; *(C_ptr2 + i) = B_shared[i];
} }
} }
...@@ -650,19 +426,21 @@ torch::Tensor awq_gemm( ...@@ -650,19 +426,21 @@ torch::Tensor awq_gemm(
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
num_out_channels, out_feats);
} }
else if (num_out_channels % 64 == 0) else if (num_out_channels % 64 == 0)
{ {
int j_factors1 = num_out_channels / 64 / 1; int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
num_out_channels, out_feats);
} }
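// Dispatch note (illustrative): the first branch launches the N = 128 instantiation, while the else-if
// branch, taken when num_out_channels is only divisible by 64, launches the N = 64 instantiation.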
return _out_feats.sum(0); return _out_feats.sum(0);
} }