Commit 9e03fc67 authored by raojy's avatar raojy
Browse files

nemotron_enable

parent 3b50924c
Pipeline #3457 failed with stages
in 0 seconds
...@@ -12,9 +12,12 @@ struct Utils { ...@@ -12,9 +12,12 @@ struct Utils {
if (!is_cached) { if (!is_cached) {
int device_id; int device_id;
cudaDeviceProp deviceProp; // cudaDeviceProp deviceProp;
cudaGetDevice(&device_id); // cudaGetDevice(&device_id);
cudaGetDeviceProperties(&deviceProp, device_id); // cudaGetDeviceProperties(&deviceProp, device_id);
hipDeviceProp_t deviceProp;
hipGetDevice(&device_id);
hipGetDeviceProperties(&deviceProp, device_id);
result = deviceProp.warpSize; result = deviceProp.warpSize;
is_cached = true; is_cached = true;
......
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project // SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <cuda_bf16.h>
#include <cuda_fp16.h>
// --- 只需 Include,不要任何宏定义 ---
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#endif
#pragma once #pragma once
#include <c10/util/BFloat16.h> #include <c10/util/BFloat16.h>
......
...@@ -293,7 +293,7 @@ PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b, ...@@ -293,7 +293,7 @@ PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Our exported C functions that call Python: // Our exported C functions that call Python:
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h // use CUstream instead of cudaStream_t, to avoid including hip/hip_runtime_api.h
void* my_malloc(ssize_t size, int device, CUstream stream) { void* my_malloc(ssize_t size, int device, CUstream stream) {
ensure_context(device); ensure_context(device);
...@@ -424,7 +424,7 @@ void* my_malloc(ssize_t size, int device, CUstream stream) { ...@@ -424,7 +424,7 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
return (void*)d_mem; return (void*)d_mem;
} }
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h // use CUstream instead of cudaStream_t, to avoid including hip/hip_runtime_api.h
void my_free(void* ptr, ssize_t size, int device, CUstream stream) { void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
// get memory handle from the pointer // get memory handle from the pointer
if (!g_python_free_callback) { if (!g_python_free_callback) {
......
...@@ -104,6 +104,6 @@ CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) { ...@@ -104,6 +104,6 @@ CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
//////////////////////////////////////// ////////////////////////////////////////
// Import CUDA headers for NVIDIA GPUs // Import CUDA headers for NVIDIA GPUs
//////////////////////////////////////// ////////////////////////////////////////
#include <cuda_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <cuda.h> #include <cuda.h>
#endif #endif
...@@ -38,15 +38,15 @@ ...@@ -38,15 +38,15 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL #define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000 // #if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below // // On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0 // // implementation is copy/pasted from the implementation in ROCm 7.0
__device__ inline void __syncwarp() { // __device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); // __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier(); // __builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); // __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
} // }
#endif // #endif
#else #else
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
#endif #endif
......
#pragma once #pragma once
#include <cuda_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <algorithm> #include <algorithm>
// maximum blocks per SM cap // maximum blocks per SM cap
......
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include "cuda_compat.h"
#include "../dispatch_utils.h"
namespace vllm {
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
bool act_first>
__global__ void act_and_mul_kernel_opt1(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
bool act_first>
__global__ void act_and_mul_kernel_opt2(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2;
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube);
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}
} // namespace vllm
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 2, ACT_FIRST> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8, ACT_FIRST> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8, ACT_FIRST> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8, ACT_FIRST> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_opt2<scalar_t, KERNEL<scalar_t>, 8, ACT_FIRST> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
});
void silu_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}
// void mul_and_silu_opt(torch::Tensor& out, // [..., d]
// torch::Tensor& input) // [..., 2 * d]
// {
// // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// // applies the silu to the latter half of the input.
// LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
// }
void gelu_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
}
void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
}
\ No newline at end of file
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace vllm {
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* src,int64_t row,int64_t col)
{
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row;
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
}
void trans_w16_gemm_cuda(half* dst,const half* src,int64_t row,int64_t col){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t num_kernels=row*col;
int block_size=256;
trans_w16_gemm_cudakernel<<<(num_kernels+block_size-1)/block_size,block_size, 0, stream>>>(num_kernels,dst,src,row,col);
}
} // namespace vllm
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col){
const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
vllm::trans_w16_gemm_cuda(
(half*)dst.data_ptr(),
(const half*)src.data_ptr(),
row,
col
);
}
\ No newline at end of file
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <torch/all.h> #include <torch/all.h>
#include <cuda_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <torch/all.h> #include <torch/all.h>
#include <cuda_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <torch/all.h> #include <torch/all.h>
#include <cuda_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
......
...@@ -46,15 +46,15 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { ...@@ -46,15 +46,15 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM) #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { // __device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val); // atomicAdd_half(address, val);
} // }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { // __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val); // atomicAdd_half2(address, val);
} // }
#endif // #endif
#endif #endif
#endif #endif
......
Contains code from https://github.com/IST-DASLab/Sparse-Marlin/
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace marlin_24 {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_;
};
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
} // namespace marlin_24
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const void* glob_ptr,
bool pred = true,
const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
} // namespace marlin_24
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
#include <cudaTypedefs.h>
namespace marlin_24 {
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
#define MMA_SP_INST \
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
const int psel) {
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
} else {
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
float c3) {
uint2 r;
asm("{\n\t"
".reg .f16 a, b, c, d; \n\t"
"cvt.rn.f16.f32 a, %2; \n\t"
"cvt.rn.f16.f32 b, %3; \n\t"
"cvt.rn.f16.f32 c, %4; \n\t"
"cvt.rn.f16.f32 d, %5; \n\t"
"mov.b32 %0, {a, b}; \n\t"
"mov.b32 %1, {c, d}; \n\t"
"}"
: "=r"(r.x), "=r"(r.y)
: "f"(c0), "f"(c1), "f"(c2), "f"(c3));
return r;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
FragS& s0, float* c4, float* c5, float* c6,
float* c7, FragS& s1) {
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
*c3 = __fmul_rn(*c3, __half2float(s0[1].y));
*c4 = __fmul_rn(*c4, __half2float(s1[0].x));
*c5 = __fmul_rn(*c5, __half2float(s1[0].y));
*c6 = __fmul_rn(*c6, __half2float(s1[1].x));
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}
} // namespace marlin_24
/*
* Notice: This file was modified by Neuralmagic inc to include 8-bit support
*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "common/base.h"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#else
#include "common/mem.h"
#include "common/mma.h"
#endif
template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin_24 {
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int THREADS = 256;
static constexpr int STAGES = 4;
static constexpr int min_thread_n = 128;
static constexpr int tile_size = 16;
static constexpr int max_par = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin_24(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
const int4* __restrict__ meta, // 2bit metadata information about 2:4
// format on B
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ s, // fp16 quantization scales of shape
// (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {}
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin_24(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
const int4* __restrict__ meta, // 2bit metadata information about 2:4
// format on B
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ s, // fp16 quantization scales of shape
// (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
if (prob_m > 16 * thread_m_blocks) {
parallel = prob_m / (16 * thread_m_blocks);
prob_m = 16 * thread_m_blocks;
}
// number of thread_k_blocks in k-dim
int k_tiles = prob_k / 32 / thread_k_blocks;
// number of thread_n_blocks in n-dim
int n_tiles = prob_n / 16 / thread_n_blocks;
// iters needed to cover all slices
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts in
// the middle of group.
if (group_blocks != -1)
iters = (group_blocks / thread_k_blocks) *
ceildiv(iters, (group_blocks / thread_k_blocks));
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
// number of threadblock tiles in the current slice
int slice_iters;
// total number of active threadblocks in the current slice
int slice_count = 0;
// index of threadblock in current slice; numbered bottom to top
int slice_idx;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
}
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) slice_idx--;
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 8;
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
}
};
init_slice();
// RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
// delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8;
// between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads //RLC: 2 * #warps k-dim
constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
// number of shared write iterations for a tile
constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
constexpr int pack_factor = 32 / num_bits;
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
constexpr int m_sh_stride =
(16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
constexpr int m_sh_wr_delta = threads / 2;
constexpr int m_sh_rd_delta = threads / 2;
constexpr int m_sh_stage = m_sh_stride * thread_k_blocks;
constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta);
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_sh_stage = s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
(threadIdx.x % (m_sh_stride));
m_gl_rd += (m_sh_stride)*slice_col;
m_gl_rd += m_gl_rd_delta_o * slice_row;
auto m_sh_wr = threadIdx.x;
auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
int s_gl_rd;
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
}
auto s_sh_wr = threadIdx.x;
int s_sh_rd;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4; // Note that in the original Marlin kernel
// this is (threadIdx.x % 32) / 4
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
}
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++) {
a_sh_rd_trans[0][i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
a_sh_rd_trans[1][i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2);
}
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
const int4* meta_ptr[m_sh_iters];
#pragma unroll
for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_s = sh_b + (stages * b_sh_stage);
int4* sh_m = sh_s + (stages * s_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks][2];
I4 frag_b_quant[2][b_thread_vecs];
FragM frag_m[2][2];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4];
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float*>(frag_c)[i] = 0;
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
}
int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < m_sh_iters; i++) {
if (m_sh_wr_pred)
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<stages - 2>();
__syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
// It may seem inefficient that we reload the groups for every sub-tile;
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
}
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
ldsm4(frag_a[k % 2][i][0],
&sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
ldsm4(frag_a[k % 2][i][1],
&sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
// Load meta with ldsm4
int4* sh_m_stage = sh_m + m_sh_stage * pipe;
ldsm4_m(frag_m[k % 2][0],
&sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
if constexpr (num_bits == 4) {
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
frag_b0 = dequant_4bit(b_quant);
frag_b1 = dequant_4bit(b_quant_shift);
} else {
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit(b_quant_0);
frag_b1 = dequant_8bit(b_quant_1);
}
// If there are no groups, we can just scale the final output once and can
// avoid doing so for each weight.
if constexpr (group_blocks != -1) {
scale(frag_b0, frag_s[k % 2][j], 0);
}
if constexpr (group_blocks != -1) {
scale(frag_b1, frag_s[k % 2][j], 1);
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
frag_m[k % 2][j / 2], j % 2);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
int c_gl_wr_delta_i =
c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
auto c_sh_wr = threadIdx.x;
int col = 2 * ((threadIdx.x % 32) % 4);
if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 ||
8 * (i / 2) + col + (i % 2) < prob_m);
}
cp_async_fence();
cp_async_wait<0>();
}
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 ||
8 * (i / 2) + col + (i % 2) < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j2 = 0; j2 < 2; j2++) {
#pragma unroll
for (int j1 = 0; j1 < 4; j1++) {
reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
4 * ((i % 4) / 2) + i % 2] +=
__half2float(
reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
}
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j2 = 0; j2 < 2; j2++) {
#pragma unroll
for (int j1 = 0; j1 < 4; j1++) {
reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
__float2half(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
4 * ((i % 4) / 2) + i % 2]);
}
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
c;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
((threadIdx.x % 32) / 4); // RLC:
c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
constexpr int c_sh_rd_delta =
c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
(threadIdx.x % (2 * 2 * thread_n_blocks));
int c_gl_wr_end = c_gl_stride * prob_m;
auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
float c4, float c5, float c6, float c7, FragS& s1) {
uint2 res[2];
res[0] = to_half4(c0, c1, c2, c3);
res[1] = to_half4(c4, c5, c6, c7);
half2* tmp = (half2*)&res;
// for per-column quantization we finally apply the scale here
if constexpr (group_blocks == -1 && num_bits == 4) {
tmp[0] = __hmul2(tmp[0], s0[0]);
tmp[1] = __hmul2(tmp[1], s0[1]);
tmp[2] = __hmul2(tmp[2], s1[0]);
tmp[3] = __hmul2(tmp[3], s1[1]);
}
((int4*)sh)[idx] = *((int4*)&res[0]);
};
// RLC: only warp 0 and 1 baseline example
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
int wr = c_sh_wr;
write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2],
frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2],
frag_s[0][2]);
write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1],
frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0],
frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3],
frag_c[i][3][0][3], frag_s[0][2]);
write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0],
frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0],
frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2],
frag_c[i][3][1][2], frag_s[0][2]);
write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1],
frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1],
frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3],
frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]);
c_sh_wr += 8 * c_sh_stride_2;
}
}
__syncthreads();
#pragma unroll
for (int i = 0;
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
zero_accums();
wait_for_stage();
fetch_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
};
start_pipes();
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines have
// even length meaning that the next iteration will always start at index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
matmul(pipe);
wait_for_stage();
fetch_to_registers(pipe + 1, (pipe + 1) % stages);
pipe++;
slice_iters--;
if (slice_iters == 0) break;
}
a_gl_rd += a_gl_rd_delta_o * stages;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (group_blocks == -1) {
if constexpr (num_bits == 8) {
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
} else {
if (last) {
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
}
}
thread_block_reduce();
if constexpr (group_blocks == -1) {
if constexpr (num_bits == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
}
} else {
if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (group_blocks == -1 && num_bits == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
&frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
&frag_c[i][0][0][2], &frag_c[i][1][0][2],
&frag_c[i][2][0][2], &frag_c[i][3][0][2],
frag_s[0][2]);
scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1],
&frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0],
&frag_c[i][0][0][3], &frag_c[i][1][0][3],
&frag_c[i][2][0][3], &frag_c[i][3][0][3],
frag_s[0][2]);
scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0],
&frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0],
&frag_c[i][0][1][2], &frag_c[i][1][1][2],
&frag_c[i][2][1][2], &frag_c[i][3][1][2],
frag_s[0][2]);
scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1],
&frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0],
&frag_c[i][0][1][3], &frag_c[i][1][1][3],
&frag_c[i][2][1][3], &frag_c[i][3][1][3],
frag_s[0][2]);
}
}
}
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
#pragma unroll
for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
#pragma unroll
for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
}
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes();
}
}
}
}
#endif
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS) { \
cudaFuncSetAttribute( \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
C_ptr, s_ptr, prob_n, \
prob_m, prob_k, locks); \
}
void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
void* s, int prob_m, int prob_n, int prob_k,
void* workspace, int num_bits, int groupsize = -1,
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
int thread_m = -1, int sms = -1, int max_par = 16) {
int tot_n = prob_n;
int tot_n_blocks = ceildiv(tot_n, 16);
int pad = 16 * tot_n_blocks - tot_n;
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
TORCH_CHECK(sms > 0);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (thread_k == -1 || thread_m == -1) {
if (prob_n <= 16) {
// For small batchizes, better partitioningif is slightly more important
// than better compute utilization
thread_k = 128;
thread_m = 128;
} else {
thread_k = 64;
thread_m = 256;
}
// Also had
// if prob_n > 256
// thread_k = 32;
// thread_m = 512;
// but this is broken,
// TODO(Lucas, Alex M): figure out why
}
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
int thread_m_blocks = thread_m / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
int blocks = sms;
TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m,
" is not divisible by thread_m = ", thread_m);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
if (group_blocks != -1) {
TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2,
" is not divisible by group_blocks = ", group_blocks);
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
const int4* meta_ptr = (const int4*)meta;
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;
constexpr int max_m_blocks = 4;
int* locks = (int*)workspace;
for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
int thread_n_blocks = tot_n_blocks - i;
prob_n = tot_n - 16 * i;
int par = 1;
if (thread_n_blocks > max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
if (par > max_par) par = max_par;
prob_n = (max_m_blocks * 16) * par;
i += max_m_blocks * (par - 1);
thread_n_blocks = max_m_blocks;
}
// For compilation speed, we only define the kernel configurations that have
// seemed useful (in terms of performance) in our testing, however many more
// are, in principle, possible.
// the false is start of the CALL_IF macros
if (false) {
} // BMxBNxBK, group
// 4-bit
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
CALL_IF_2_4(4, 16, 2, 2, 4)
CALL_IF_2_4(4, 16, 3, 2, -1)
CALL_IF_2_4(4, 16, 3, 2, 4)
CALL_IF_2_4(4, 16, 4, 2, -1)
CALL_IF_2_4(4, 16, 4, 2, 4)
CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64
CALL_IF_2_4(4, 32, 2, 1, 4)
CALL_IF_2_4(4, 32, 3, 1, -1)
CALL_IF_2_4(4, 32, 3, 1, 4)
CALL_IF_2_4(4, 32, 4, 1, -1)
CALL_IF_2_4(4, 32, 4, 1, 4)
// 8-bit
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
CALL_IF_2_4(8, 16, 2, 2, 4)
CALL_IF_2_4(8, 16, 3, 2, -1)
CALL_IF_2_4(8, 16, 3, 2, 4)
CALL_IF_2_4(8, 16, 4, 2, -1)
CALL_IF_2_4(8, 16, 4, 2, 4)
CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64
CALL_IF_2_4(8, 32, 2, 1, 4)
CALL_IF_2_4(8, 32, 3, 1, -1)
CALL_IF_2_4(8, 32, 3, 1, 4)
CALL_IF_2_4(8, 32, 4, 1, -1)
CALL_IF_2_4(8, 32, 4, 1, 4)
else {
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
", " + str(prob_k) + ", " + str(prob_n) + "]" +
", groupsize = " + str(groupsize) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par;
}
}
} // namespace marlin_24
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n,
int64_t size_k) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
// Verify num_bits
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
int pack_factor = 32 / b_q_type.size_bits();
// Verify M
TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
", size_m = " + str(size_m));
// Verify K
TORCH_CHECK(size_k == a.size(1),
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
TORCH_CHECK(size_k % marlin_24::tile_size == 0,
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
str(marlin_24::tile_size));
TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " +
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
", tile_size = " + str(marlin_24::tile_size));
// Verify N
TORCH_CHECK(b_scales.size(1) == size_n,
"b_scales.size(1) = " + str(b_scales.size(1)) +
", size_n = " + str(size_n));
TORCH_CHECK(
b_q_weight.size(1) % marlin_24::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(marlin_24::tile_size));
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
TORCH_CHECK(
size_n == actual_size_n,
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
// Verify meta
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
"b_meta.size(0) = ", b_meta.size(0),
" is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2);
TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1),
" is not size_n * 2 = ", size_n * 2);
// Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(a.dtype() == torch::kFloat16,
"A is not float16, currently only float16 is supported");
// Verify B device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
// Verify b_meta device and strides
TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU");
TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous");
// Verify scales device and strides
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
"A is not float16, currently only float16 is supported");
// Alloc C matrix
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, size_n}, options);
int thread_k = -1;
int thread_m = -1;
int sms = -1;
int max_par = marlin_24::max_par;
int groupsize = -1;
if (b_scales.size(0) > 1) {
TORCH_CHECK(size_k % b_scales.size(0) == 0,
"size_k = " + str(size_k) +
", is not divisible by b_scales.size(0) = " +
str(b_scales.size(0)));
groupsize = size_k / b_scales.size(0);
groupsize /= 2; // Because of 24
}
// Verify groupsize
TORCH_CHECK(groupsize == -1 || groupsize == 64,
"Unexpected groupsize = " + str(groupsize));
// Verify workspace size
TORCH_CHECK(size_n % marlin_24::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " +
str(marlin_24::min_thread_n));
int min_workspace_size =
(size_n / marlin_24::min_thread_n) * marlin_24::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) +
" is below min_workspace_size = " + str(min_workspace_size));
int dev = a.get_device();
marlin_24::marlin_cuda_2_4(
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_m, sms, max_par);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
}
...@@ -4,7 +4,7 @@ ninja ...@@ -4,7 +4,7 @@ ninja
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<81.0.0 setuptools>=77.0.3,<81.0.0
setuptools-scm>=8 setuptools-scm>=8
torch==2.10.0 #torch==2.10.0
wheel wheel
jinja2>=3.1.6 jinja2>=3.1.6
regex regex
......
...@@ -2,7 +2,7 @@ regex # Replace re for higher-performance regex matching ...@@ -2,7 +2,7 @@ regex # Replace re for higher-performance regex matching
cachetools cachetools
psutil psutil
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy #numpy
requests >= 2.26.0 requests >= 2.26.0
tqdm tqdm
blake3 blake3
...@@ -32,7 +32,7 @@ pyzmq >= 25.0.0 ...@@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.17.0 gguf >= 0.17.0
mistral_common[image] >= 1.9.1 mistral_common[image] >= 1.9.1
opencv-python-headless >= 4.13.0 # required for video IO #opencv-python-headless >= 4.13.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
-r common.txt -r common.txt
--extra-index-url https://download.pytorch.org/whl/rocm7.1 --extra-index-url https://download.pytorch.org/whl/rocm7.1
torch==2.10.0 #torch==2.10.0
torchvision==0.25.0 #torchvision==0.25.0
torchaudio==2.10.0 #torchaudio==2.10.0
triton==3.6.0 #triton==3.6.0
cmake>=3.26.1,<4 cmake>=3.26.1,<4
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<80.0.0 setuptools>=77.0.3,<80.0.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment