Commit 6640dc0b authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 44d4d334 83e4e0fe
......@@ -48,6 +48,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
......
......@@ -7,7 +7,7 @@ steps:
queue: cpu_queue
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"
......@@ -19,6 +19,34 @@ steps:
limit: 5
- wait
- group: "AMD Tests"
depends_on: ~
steps:
{% for step in steps %}
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
- label: "AMD: {{ step.label }}"
agents:
queue: amd
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
env:
DOCKER_BUILDKIT: "1"
soft_fail: true
{% endif %}
{% endfor %}
- label: "Neuron Test"
depends_on: ~
agents:
queue: neuron
command: bash .buildkite/run-neuron-test.sh
soft_fail: false
- label: "Intel Test"
depends_on: ~
agents:
queue: intel
command: bash .buildkite/run-cpu-test.sh
{% for step in steps %}
- label: "{{ step.label }}"
agents:
......@@ -31,7 +59,7 @@ steps:
{% else %}
queue: gpu_1_queue
{% endif %}
soft_fail: true
soft_fail: {{ step.soft_fail or false }}
{% if step.parallelism %}
parallelism: {{ step.parallelism }}
{% endif %}
......
......@@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2
pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2
- name: Analysing the code with ruff
run: |
ruff .
......
......@@ -179,9 +179,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
......@@ -189,7 +189,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
......
......@@ -10,7 +10,7 @@
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
&& apt-get install -y python3-pip git
&& apt-get install -y python3-pip git curl sudo
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
......@@ -27,6 +27,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-cuda.txt
# install development dependencies
COPY requirements-lint.txt requirements-lint.txt
COPY requirements-test.txt requirements-test.txt
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt
......@@ -70,10 +72,28 @@ ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
ARG USE_SCCACHE
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
&& tar -xzf sccache.tar.gz \
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
&& export SCCACHE_BUCKET=vllm-build-sccache \
&& export SCCACHE_REGION=us-west-2 \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& sccache --show-stats; \
fi
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist
if [ "$USE_SCCACHE" != "1" ]; then \
python3 setup.py bdist_wheel --dist-dir=dist; \
fi
# check the size of the wheel, we cannot upload wheels larger than 100MB
COPY .buildkite/check-wheel-size.py check-wheel-size.py
......
......@@ -3,9 +3,13 @@
FROM ubuntu:22.04 AS cpu-test-1
RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
......
ARG NIGHTLY_DATE="20240601"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE
WORKDIR /workspace
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Build vLLM.
RUN cd /workspace/vllm && python setup.py develop
CMD ["/bin/bash"]
......@@ -74,7 +74,7 @@ python3 setup.py install
+ 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
## 验证
- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.4.3
- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.5.0.post1
## Known Issue
-
......
......@@ -23,16 +23,10 @@ If you have cool projects related to vLLM or LLM inference, we would love to see
This will be a great chance for everyone in the community to get together and learn.
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite)
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
We are thrilled to announce our fourth vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://lu.ma/agivllm) and join us!
---
*Latest News* 🔥
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM.
......@@ -65,7 +59,7 @@ vLLM is flexible and easy to use with:
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs and AMD GPUs
- Support NVIDIA GPUs, AMD GPUs, and Intel CPUs
- (Experimental) Prefix caching support
- (Experimental) Multi-lora support
......
......@@ -68,9 +68,13 @@ async def async_request_tgi(
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk_bytes = chunk_bytes.decode("utf-8")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
#NOTE: Sometimes TGI returns a ping response without
# any data, we should skip it.
if chunk_bytes.startswith(":"):
continue
chunk = remove_prefix(chunk_bytes, "data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
......
......@@ -189,7 +189,7 @@ if __name__ == '__main__':
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size',
type=int,
......
......@@ -353,7 +353,7 @@ if __name__ == "__main__":
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument(
"--enable-prefix-caching",
......
......@@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype)
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
# bench
......
......@@ -33,6 +33,7 @@ function (find_isa CPUINFO TARGET OUT)
endif()
endfunction()
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
if (AVX512_FOUND)
......@@ -53,8 +54,11 @@ if (AVX512_FOUND)
else()
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 ISA support.")
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
......
......@@ -5,6 +5,10 @@
#include <immintrin.h>
#include <torch/all.h>
#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
......@@ -104,6 +108,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};
#ifdef __AVX512F__
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
......@@ -123,6 +128,34 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m256i reg_low;
__m256i reg_high;
explicit BF16Vec32(const void *ptr)
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
reg_high(high) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg_low((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)),
reg_high((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)) {}
void save(void *ptr) const {
*reinterpret_cast<__m256i *>(ptr) = reg_low;
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
}
};
#endif
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
......@@ -226,6 +259,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
......@@ -290,6 +324,114 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m256 reg;
float values[8];
};
__m256 reg_low;
__m256 reg_high;
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
reg_high(_mm256_set1_ps(v)) {}
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
reg_high(_mm256_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
reg_high(_mm256_loadu_ps(ptr + 8)) {}
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
reg_high(data.reg_high) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)),
reg_high((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg_low(data.reg), reg_high(data.reg) {}
explicit FP32Vec16(const BF16Vec16 &v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);
__m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
__m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
__m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
__m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
reg_low = _mm256_castsi256_ps(v_low_shifted);
reg_high = _mm256_castsi256_ps(v_high_shifted);
}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
_mm256_mul_ps(reg_high, b.reg_high));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
_mm256_add_ps(reg_high, b.reg_high));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
_mm256_sub_ps(reg_high, b.reg_high));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
_mm256_div_ps(reg_high, b.reg_high));
}
float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high);
return low.reduce_sum() + high.reduce_sum();
}
template <int group_size> float reduce_sub_sum(int idx) {
float sum = 0.0;
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
uint32_t mask = base_mask << (idx * group_size);
AliasReg ar;
auto func = [&sum, &mask, &ar](int i) {
int flag = mask & 0x1;
mask = mask >> 1;
if (flag != 0) sum += ar.values[i];
};
ar.reg = reg_low;
unroll_loop<int, 8>(func);
ar.reg = reg_high;
unroll_loop<int, 8>(func);
return sum;
}
void save(float *ptr) const {
_mm256_storeu_ps(ptr, reg_low);
_mm256_storeu_ps(ptr + 8, reg_high);
}
};
#endif
template <typename T> struct VecType { using vec_type = void; };
......@@ -336,6 +478,7 @@ template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*ptr = *(v_ptr + 1);
}
#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
......@@ -343,7 +486,27 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#endif
#else
namespace{
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
__m256i ai = _mm256_castps_si256(a);
ai = _mm256_srli_epi32(ai, 16);
ai = _mm256_packus_epi32(ai, ai);
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
return _mm256_extracti128_si256(ai, 0);
}
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif // __AVX512F__
#endif // __AVX512BF16__
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
......
......@@ -90,7 +90,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
......
......@@ -29,21 +29,14 @@
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace {
......@@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
......@@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
......@@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
......@@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
evt0_compute_args};
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt1_compute_args,
evt_args,
d_args,
};
......@@ -229,7 +252,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
} // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
......@@ -243,20 +266,20 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
......@@ -270,20 +293,20 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
......@@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
}
}
}
......@@ -32,21 +32,14 @@
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace {
......@@ -71,21 +64,25 @@ struct enable_sm90_or_later : Kernel {
}
};
template <typename ElementAB_, typename ElementD_, typename TileShape,
typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
......@@ -111,19 +108,53 @@ struct cutlass_3x_gemm {
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
return ArgumentType{a_args, {b_args}};
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute1>::CollectiveOp;
EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
......@@ -148,11 +179,10 @@ struct cutlass_3x_gemm {
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
......@@ -182,19 +212,13 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
args.epilogue.thread = {a_args, {b_args}};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
......@@ -209,7 +233,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(status);
}
template <typename InType, typename OutType, int32_t M>
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, int32_t M>
struct sm90_fp8_config {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
......@@ -219,12 +244,13 @@ struct sm90_fp8_config {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 128> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
......@@ -233,12 +259,13 @@ struct sm90_fp8_config<InType, OutType, 128> {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 64> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
......@@ -247,30 +274,28 @@ struct sm90_fp8_config<InType, OutType, 64> {
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
} // namespace
template <typename InType, typename OutType>
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using Cutlass3xGemmDefault =
typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
......@@ -278,20 +303,20 @@ void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
......@@ -308,16 +333,15 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<cutlass_3x_gemm<
int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(
return cutlass_gemm_caller<
cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
}
} else {
......@@ -325,13 +349,13 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t>(
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t>(
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
......
......@@ -3,29 +3,29 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int32_t major_capability;
......@@ -57,19 +57,19 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
#else
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
}
}
......@@ -23,8 +23,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float scale) {
float x = static_cast<float>(val) / scale;
const scalar_t val, const float inverted_scale) {
float x = static_cast<float>(val) * inverted_scale;
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
}
......@@ -71,15 +71,56 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
}
}
template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};
typedef struct __align__(4) {
c10::Float8_e4m3fn x;
c10::Float8_e4m3fn y;
c10::Float8_e4m3fn z;
c10::Float8_e4m3fn w;
}
float8x4_t;
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
while (i < num_elems) {
out[i] = scaled_fp8_conversion(input[i], *scale);
i += blockDim.x * gridDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
// Vectorized input/output to better utilize memory bandwidth.
const vec4_t<scalar_t>* vectorized_in =
reinterpret_cast<const vec4_t<scalar_t>*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
int num_vec_elems = num_elems >> 2;
#pragma unroll 4
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
vectorized_out[i] = out_vec;
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems;
i += blockDim.x * gridDim.x) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment