"vllm/vscode:/vscode.git/clone" did not exist on "262d263f6c56fa95e15422d3a475da8efdf67cc1"
Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -52,15 +52,6 @@ ...@@ -52,15 +52,6 @@
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \
AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
......
...@@ -140,6 +140,211 @@ fused_add_rms_norm_kernel( ...@@ -140,6 +140,211 @@ fused_add_rms_norm_kernel(
} }
} }
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
using Base = _f16Vec<scalar_t, width>;
using Converter = typename Base::Converter;
using T1 = typename Base::T1;
using T2 = typename Base::T2;
using Base::data;
__device__ auto sum_pows() const {
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x4 = x2 * x2;
float x6 = x4 * x2;
float y2 = z.y * z.y;
float y4 = y2 * y2;
float y6 = y4 * y2;
s2 += x2 + y2;
s4 += x4 + y4;
s6 += x6 + y6;
}
return std::make_tuple(s2, s4, s6);
}
__device__ void poly_norm_inplace(const float w2_inv_std,
const float w1_inv_std2,
const float w0_inv_std3, const float bias) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x3 = x2 * z.x;
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
float y2 = z.y * z.y;
float y3 = y2 * z.y;
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
auto out = Converter::convert(z);
data[i] = out.x;
data[i + 1] = out.y;
}
}
};
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
const int vec_hidden_size = hidden_size / width;
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
auto [x2, x4, x6] = temp.sum_pows();
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
out_v[id] = temp;
}
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x4 = x2 * x2;
float x6 = x4 * x2;
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x3 = x2 * x;
out[blockIdx.x * hidden_size + idx] =
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
s_bias);
}
}
} // namespace vllm } // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size] void rms_norm(torch::Tensor& out, // [..., hidden_size]
...@@ -219,3 +424,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] ...@@ -219,3 +424,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM(0); LAUNCH_FUSED_ADD_RMS_NORM(0);
} }
} }
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
void poly_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [3]
torch::Tensor& bias, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.data_ptr() != input.data_ptr());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_POLY_NORM(8);
} else {
LAUNCH_FUSED_POLY_NORM(0);
}
}
...@@ -27,11 +27,12 @@ ...@@ -27,11 +27,12 @@
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
bool kIsVariableB_, bool kIsVariableC_, bool kIsVariableB_, bool kIsVariableC_,
bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_> bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
struct Selective_Scan_fwd_kernel_traits { struct Selective_Scan_fwd_kernel_traits {
static_assert(kNItems_ % 4 == 0); static_assert(kNItems_ % 4 == 0);
using input_t = input_t_; using input_t = input_t_;
using weight_t = weight_t_; using weight_t = weight_t_;
using state_t = state_t_;
static constexpr int kNThreads = kNThreads_; static constexpr int kNThreads = kNThreads_;
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
...@@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride; weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride + cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride; dim_id * kNRows * params.ssm_states_dim_stride;
...@@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix; smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) { if (chunk == n_chunks - 1) {
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y); ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
} }
} }
#pragma unroll #pragma unroll
...@@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
} }
} }
template<int kNThreads, int kNItems, typename input_t, typename weight_t> template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row. // processing 1 row.
...@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { ...@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>; using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t, state_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows); dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>; auto kernel = &selective_scan_fwd_kernel<Ktraits>;
...@@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { ...@@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
}); });
} }
template<typename input_t, typename weight_t> template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) { void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#ifndef USE_ROCM #ifndef USE_ROCM
if (params.seqlen <= 128) { if (params.seqlen <= 128) {
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) { } else if (params.seqlen <= 256) {
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) { } else if (params.seqlen <= 512) {
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 1024) { } else if (params.seqlen <= 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else { } else {
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
} }
#else #else
if (params.seqlen <= 256) { if (params.seqlen <= 256) {
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) { } else if (params.seqlen <= 512) {
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 1024) { } else if (params.seqlen <= 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else { } else {
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
} }
#endif #endif
} }
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream); template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream); template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ #define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \ if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \ using input_t = at::Half; \
using weight_t = float; \ using weight_t = float; \
__VA_ARGS__(); \ if (STYPE == at::ScalarType::Half) { \
using state_t = at::Half; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
} \
} else if (ITYPE == at::ScalarType::BFloat16) { \ } else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \ using input_t = at::BFloat16; \
using weight_t = float; \ using weight_t = float; \
__VA_ARGS__(); \ if (STYPE == at::ScalarType::BFloat16) { \
using state_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
} \
} else if (ITYPE == at::ScalarType::Float) { \ } else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \ using input_t = float; \
using weight_t = float; \ using weight_t = float; \
using state_t = float; \
__VA_ARGS__(); \ __VA_ARGS__(); \
} else { \ } else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
} }
template<typename input_t, typename weight_t> template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
void set_ssm_params_fwd(SSMParamsBase &params, void set_ssm_params_fwd(SSMParamsBase &params,
...@@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta; at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type); // ssm_states can now be either the same as input_type or float32
auto state_type = ssm_states.scalar_type();
TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1); TORCH_CHECK(ssm_states.stride(-1) == 1);
...@@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream); selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
}); });
} }
...@@ -28,6 +28,7 @@ namespace cg = cooperative_groups; ...@@ -28,6 +28,7 @@ namespace cg = cooperative_groups;
namespace vllm { namespace vllm {
namespace moe { namespace moe {
constexpr float kNegInfinity = INFINITY * -1;
constexpr unsigned FULL_WARP_MASK = 0xffffffff; constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32; constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512; constexpr int32_t BLOCK_SIZE = 512;
...@@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel(
warp_id * topk; warp_id * topk;
s_topk_idx += warp_id * topk; s_topk_idx += warp_id * topk;
T value = cuda::std::numeric_limits<T>::min(); T value = kNegInfinity;
T topk_group_value = cuda::std::numeric_limits<T>::min(); T topk_group_value = kNegInfinity;
int32_t num_equalto_topkth_group; int32_t num_equalto_topkth_group;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
...@@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel(
__syncwarp(); // Ensure all threads have valid data before reduction __syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>()); topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) { if (value == topk_group_value) {
value = cuda::std::numeric_limits<T>::min(); value = kNegInfinity;
} }
pre_count_equal_to_top_value = count_equal_to_top_value; pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync( count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min()))); FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
} }
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
} }
...@@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int count_equalto_topkth_group = 0; int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = bool if_proceed_next_topk =
(topk_group_value != cuda::std::numeric_limits<T>::min()); (topk_group_value != cuda_cast<T, float>(kNegInfinity));
if (case_id < num_tokens && if_proceed_next_topk) { if (case_id < num_tokens && if_proceed_next_topk) {
for (int i_group = 0; i_group < n_group; i_group++) { for (int i_group = 0; i_group < n_group; i_group++) {
if ((group_scores[i_group] > topk_group_value) || if ((group_scores[i_group] > topk_group_value) ||
...@@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel(
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>( (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
scores_with_bias[offset + i])) scores_with_bias[offset + i]))
? scores_with_bias[offset + i] ? scores_with_bias[offset + i]
: cuda::std::numeric_limits<T>::min(); : cuda_cast<T, float>(kNegInfinity);
queue.add(candidates, offset + i); queue.add(candidates, offset + i);
} }
if (group_scores[i_group] == topk_group_value) { if (group_scores[i_group] == topk_group_value) {
......
...@@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ...@@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& bias, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits, void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask, const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask, const torch::Tensor& output_mask,
...@@ -130,8 +133,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); ...@@ -130,8 +133,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, // void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale); // torch::Tensor& scale);
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #ifndef USE_ROCM
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant(torch::Tensor& out, void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& output_block_scale, torch::Tensor& output_block_scale,
torch::Tensor& input, torch::Tensor& input,
...@@ -356,4 +358,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles); ...@@ -356,4 +358,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false); int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size(); int64_t qr_max_size();
#endif #endif
\ No newline at end of file
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "core/registration.h" #include "core/registration.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include <limits>
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
...@@ -169,6 +170,11 @@ struct W4A8GemmKernel { ...@@ -169,6 +170,11 @@ struct W4A8GemmKernel {
int k = A.size(1); int k = A.size(1);
int n = B.size(1); int n = B.size(1);
// safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);
// Allocate output // Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto device = A.device(); auto device = A.device();
...@@ -181,7 +187,7 @@ struct W4A8GemmKernel { ...@@ -181,7 +187,7 @@ struct W4A8GemmKernel {
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr()); auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr()); auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto D_ptr = static_cast<ElementD*>(D.data_ptr()); auto D_ptr = static_cast<ElementD*>(D.data_ptr());
// can we avoid harcode the 8 here // can we avoid hardcode the 8 here
auto S_ptr = auto S_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>( static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
group_scales.const_data_ptr()); group_scales.const_data_ptr());
...@@ -192,7 +198,7 @@ struct W4A8GemmKernel { ...@@ -192,7 +198,7 @@ struct W4A8GemmKernel {
cute::tile_to_shape(LayoutAtomQuant{}, shape_B); cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
// strides // strides
int const scale_k = cutlass::ceil_div(k, group_size); int const scale_k = cutlass::ceil_div(k, group_size_int);
StrideA stride_A = StrideA stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
// Reverse stride here due to swap and transpose // Reverse stride here due to swap and transpose
...@@ -211,8 +217,8 @@ struct W4A8GemmKernel { ...@@ -211,8 +217,8 @@ struct W4A8GemmKernel {
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
MainloopArguments mainloop_arguments{ MainloopArguments mainloop_arguments{
B_ptr, layout_B_reordered, A_ptr, stride_A, B_ptr, layout_B_reordered, A_ptr, stride_A,
S_ptr, stride_S, group_size}; S_ptr, stride_S, group_size_int};
EpilogueArguments epilogue_arguments{ EpilogueArguments epilogue_arguments{
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
......
...@@ -14,9 +14,6 @@ ...@@ -14,9 +14,6 @@
#include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
namespace vllm { namespace vllm {
......
...@@ -14,9 +14,6 @@ ...@@ -14,9 +14,6 @@
#include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
namespace vllm { namespace vllm {
......
...@@ -13,27 +13,18 @@ ...@@ -13,27 +13,18 @@
#include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
namespace vllm { namespace vllm {
using namespace cute; using namespace cute;
template <typename SchedulerType, typename OutType, int GroupSizeM_, // clang-format off
int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128, template <class OutType, int ScaleGranularityM,
class ClusterShape = Shape<_1, _2, _1>> int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise { struct cutlass_3x_gemm_fp8_blockwise {
using GroupSizeM = Int<GroupSizeM_>;
using GroupSizeN = Int<GroupSizeN_>;
using GroupSizeK = Int<GroupSizeK_>;
using TileSizeM = Int<TileSizeM_>;
static_assert(TileSizeM_ % GroupSizeM_ == 0,
"TileSizeM must be a multiple of GroupSizeM");
using ElementAB = cutlass::float_e4m3_t; using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB; using ElementA = ElementAB;
...@@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise {
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementD = OutType; using ElementD = OutType;
using StrideD = Stride<int64_t, Int<1>, Int<0>>; using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ElementC = void; using ElementC = void; // TODO: support bias
using StrideC = StrideD; using LayoutC = LayoutD;
static constexpr int AlignmentC = AlignmentD; static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float; using ElementAccumulator = float;
using ElementBlockScale = float;
using ElementCompute = float; using ElementCompute = float;
using ElementBlockScale = float;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm90; using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using KernelSchedule = cutlass::gemm:: using ElementScalar = float;
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
GroupSizeM_>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; ArchTag,
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; OperatorClass,
MmaTileShape,
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< ClusterShape,
cutlass::epilogue::fusion::Sm90AccFetch>; cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
using CollectiveEpilogue = ElementCompute,
typename cutlass::epilogue::collective::CollectiveBuilder< ElementC,
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, LayoutC,
ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, AlignmentC,
ElementD, StrideD, AlignmentD, EpilogueSchedule, ElementD,
StoreEpilogueCompute>::CollectiveOp; LayoutD,
AlignmentD,
using CollectiveMainloop = EpilogueScheduler,
typename cutlass::gemm::collective::CollectiveBuilder< DefaultOperation
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, >::CollectiveOp;
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
sizeof(typename CollectiveEpilogue::SharedStorage))>, ArchTag,
KernelSchedule>::CollectiveOp; OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal< using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
SchedulerType>>;
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
using StrideA = typename GemmKernel::StrideA;
using StrideB = typename GemmKernel::StrideB;
}; };
template <typename Gemm> template <typename Gemm>
...@@ -99,76 +105,54 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -99,76 +105,54 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
auto prob_shape = c3x::get_problem_shape(a, b); int32_t m = a.size(0), n = b.size(1), k = a.size(1);
int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
k = get<2>(prob_shape);
int64_t lda = a.stride(0); TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>; StrideA a_stride;
using StrideB = Stride<int64_t, Int<1>, int64_t>; StrideB b_stride;
using StrideC = typename Gemm::StrideC; StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
StrideA a_stride{lda, Int<1>{}, 0}; LayoutSFA layout_SFA =
StrideB b_stride{ldb, Int<1>{}, 0}; ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB*>(a.data_ptr()); auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr()); auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr()); auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr()); auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
// Check is the t is contiguous and is 1D or 2D with one of the dimensions auto mainloop_args = [&](){
// being 1 (i.e. a row or column vector) return typename GemmKernel::MainloopArguments{
auto is_contiguous_vector = [](const torch::Tensor& t) { a_ptr, a_stride, b_ptr, b_stride,
auto t_sizes = t.sizes(); a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
return t.is_contiguous() && };
(t.dim() == 1 || }();
(t.dim() == 2 && auto prob_shape = cute::make_shape(m, n, k, 1);
*std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
};
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
// we don't have to deal with enforcing implicit layouts
TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
"a_scales must be M major");
TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
"b_scales must be K major");
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
auto c_ptr = static_cast<ElementD*>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride}; {}, c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::TileSchedulerArguments scheduler;
static constexpr bool UsesStreamKScheduler =
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
cutlass::gemm::StreamKScheduler>;
if constexpr (UsesStreamKScheduler) {
using DecompositionMode = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
using ReductionMode = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
scheduler.decomposition_mode = DecompositionMode::StreamK;
scheduler.reduction_mode = ReductionMode::Nondeterministic;
}
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args, c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args, scheduler); epilogue_args);
} }
template <typename OutType> template <typename OutType>
...@@ -177,18 +161,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, ...@@ -177,18 +161,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
auto k = a.size(1); // TODO: better heuristics
auto n = b.size(1); cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,
if (k > 3 * n) { Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>(
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>( out, a, b, a_scales, b_scales);
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
out, a, b, a_scales, b_scales);
}
} }
} // namespace vllm } // namespace vllm
\ No newline at end of file
...@@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 100) { if (version_num >= 90) {
TORCH_CHECK( TORCH_CHECK(
a.size(0) == a_scales.size(0) && a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
...@@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128]."); "b_scale_group_shape must be [128, 128].");
} else {
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
} }
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
......
...@@ -26,164 +26,17 @@ ...@@ -26,164 +26,17 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "nvfp4_utils.cuh"
namespace vllm { namespace vllm {
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = c10::Half;
};
template <>
struct TypeConverter<c10::Half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = c10::BFloat16;
};
template <>
struct TypeConverter<c10::BFloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
template <class Type> template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec, __inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
PackedVec<Type>& vec2) { PackedVec<Type>& vec2) {
PackedVec<Type> result; PackedVec<Type> result;
#pragma unroll #pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
if constexpr (std::is_same_v<Type, c10::Half>) { if constexpr (std::is_same_v<Type, half>) {
half2 val(0.5f, 0.5f); half2 val(0.5f, 0.5f);
half2 t0 = __hmul2(vec.elts[i], val); half2 t0 = __hmul2(vec.elts[i], val);
half2 t1 = __hfma2(h2tanh(t0), val, val); half2 t1 = __hfma2(h2tanh(t0), val, val);
...@@ -206,13 +59,12 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, ...@@ -206,13 +59,12 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
PackedVec<Type>& vec2, PackedVec<Type>& vec2,
float SFScaleVal, float SFScaleVal,
uint8_t* SFout) { uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
PackedVec<Type> out_silu = compute_silu(vec, vec2); PackedVec<Type> out_silu = compute_silu(vec, vec2);
// Get absolute maximum values among the local 8 values. // Get absolute maximum values among the local 8 values.
auto localMax = __habs2(out_silu.elts[0]); auto localMax = __habs2(out_silu.elts[0]);
// Local maximum value. // Local maximum value.
#pragma unroll #pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(out_silu.elts[i])); localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
} }
...@@ -259,9 +111,9 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, ...@@ -259,9 +111,9 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
// Convert the input to float. // Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll #pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, c10::Half>) { if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(out_silu.elts[i]); fp2Vals[i] = __half22float2(out_silu.elts[i]);
} else { } else {
fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]); fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
...@@ -275,22 +127,14 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, ...@@ -275,22 +127,14 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
// Write the e2m1 values to global memory. // Write the e2m1 values to global memory.
return e2m1Vec; return e2m1Vec;
#else
return 0;
#endif
} }
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __global__ void __launch_bounds__(1024, 4)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4( float const* SFScale, uint32_t* out,
#else uint32_t* SFout) {
silu_and_cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
...@@ -328,22 +172,25 @@ silu_and_cvt_fp16_to_fp4( ...@@ -328,22 +172,25 @@ silu_and_cvt_fp16_to_fp4(
in_vec, in_vec2, SFScaleVal, sf_out); in_vec, in_vec2, SFScaleVal, sf_out);
} }
} }
#endif
} }
} // namespace vllm } // namespace vllm
void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
torch::Tensor& output_sf, torch::Tensor& output_sf,
torch::Tensor& input, // [..., 2 * d] torch::Tensor& input, // [..., 2 * d]
torch::Tensor& input_sf) { torch::Tensor& input_sf) {
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
int32_t m = input.size(0); int32_t m = input.size(0);
int32_t n = input.size(1) / 2; int32_t n = input.size(1) / 2;
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount = int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1); get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr()); auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr()); auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr()); auto output_ptr = static_cast<int64_t*>(output.data_ptr());
...@@ -352,17 +199,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] ...@@ -352,17 +199,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d]
dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
int const numBlocksPerSM = 2048 / block.x; int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
VLLM_DISPATCH_HALF_TYPES( VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "act_and_mul_quant_kernel", [&] { input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr()); using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
VLLM_DISPATCH_BYTE_TYPES( auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type", vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
[&] { m, n, input_ptr, input_sf_ptr,
vllm::silu_and_cvt_fp16_to_fp4<scalar_t> reinterpret_cast<uint32_t*>(output_ptr),
<<<grid, block, 0, stream>>>( reinterpret_cast<uint32_t*>(sf_out));
m, n, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}); });
} }
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h> #include <torch/all.h>
#include <cutlass/arch/arch.h> #include <cutlass/arch/arch.h>
......
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h> #include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include "dispatch_utils.h"
template <typename T> #include "nvfp4_utils.cuh"
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF> namespace vllm {
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void __global__ void __launch_bounds__(512, 4)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
__launch_bounds__(512, 4) cvt_fp16_to_fp4( float const* SFScale, uint32_t* out, uint32_t* SFout,
#else uint32_t* input_offset_by_experts,
cvt_fp16_to_fp4( uint32_t* output_scale_offset_by_experts, int n_experts,
#endif bool low_latency) {
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
...@@ -299,8 +94,8 @@ cvt_fp16_to_fp4( ...@@ -299,8 +94,8 @@ cvt_fp16_to_fp4(
&input_offset_by_experts[chunk_start + 12])); &input_offset_by_experts[chunk_start + 12]));
local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
// Check against the 16 loaded offsets // Check against the 16 loaded offsets
#pragma unroll #pragma unroll
for (int i = 0; i < 16; i++) { for (int i = 0; i < 16; i++) {
if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
rowIdx_in_expert = rowIdx - local_offsets[i]; rowIdx_in_expert = rowIdx - local_offsets[i];
...@@ -330,21 +125,15 @@ cvt_fp16_to_fp4( ...@@ -330,21 +125,15 @@ cvt_fp16_to_fp4(
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
} }
#endif
} }
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version) // Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void __global__ void __launch_bounds__(1024, 4)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
__launch_bounds__(1024, 4) cvt_fp16_to_fp4( float const* SFScale, uint32_t* out, uint32_t* SFout,
#else uint32_t* input_offset_by_experts,
cvt_fp16_to_fp4( uint32_t* output_scale_offset_by_experts, int n_experts) {
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
...@@ -425,7 +214,6 @@ cvt_fp16_to_fp4( ...@@ -425,7 +214,6 @@ cvt_fp16_to_fp4(
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
} }
#endif
} }
template <typename T> template <typename T>
...@@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input,
} }
} }
} // namespace vllm
/*Quantization entry for fp4 experts quantization*/ /*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \ #define CHECK_CONTIGUOUS(x, m) \
...@@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a( ...@@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a(
// 4 means 4 fp8 values are packed into one int32 // 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k); TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device()); at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(output.data_ptr(), output_scale.data_ptr(), VLLM_DISPATCH_HALF_TYPES(
input.data_ptr(), input_global_scale.data_ptr(), input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
input_offset_by_experts.data_ptr(), using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
output_scale_offset_by_experts.data_ptr(), m_topk, k, vllm::quant_impl<cuda_type>(
n_experts, stream); output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
} else if (in_dtype == at::ScalarType::BFloat16) { input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
input.data_ptr(), input_global_scale.data_ptr(), stream);
input_offset_by_experts.data_ptr(), });
output_scale_offset_by_experts.data_ptr(), m_topk,
k, n_experts, stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
} }
...@@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a( ...@@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a(
torch::Tensor const& output_scale_offset_by_experts); torch::Tensor const& output_scale_offset_by_experts);
#endif #endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& output_sf,
torch::Tensor& input,
torch::Tensor& input_sf);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) { torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
...@@ -54,3 +62,13 @@ void scaled_fp4_experts_quant( ...@@ -54,3 +62,13 @@ void scaled_fp4_experts_quant(
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel"); "No compiled nvfp4 experts quantization kernel");
} }
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
...@@ -23,245 +23,18 @@ ...@@ -23,245 +23,18 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "nvfp4_utils.cuh"
// Get type2 from type or vice versa (applied to half and bfloat16) namespace vllm {
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __global__ void __launch_bounds__(512, 4)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
__launch_bounds__(512, 4) cvt_fp16_to_fp4( float const* SFScale, uint32_t* out, uint32_t* SFout) {
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
...@@ -293,7 +66,6 @@ cvt_fp16_to_fp4( ...@@ -293,7 +66,6 @@ cvt_fp16_to_fp4(
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
} }
} }
#endif
} }
template <typename T> template <typename T>
...@@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, ...@@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
int multiProcessorCount, int multiProcessorCount,
cudaStream_t stream); cudaStream_t stream);
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& output_sf, torch::Tensor const& output_sf,
...@@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, ...@@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
int32_t n = input.size(1); int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount = int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1); get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
...@@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, ...@@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
// We don't support e8m0 scales at this moment. // We don't support e8m0 scales at this moment.
bool useUE8M0 = false; bool useUE8M0 = false;
switch (input.scalar_type()) { VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
case torch::kHalf: { using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr()); auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
useUE8M0, multiProcessorCount, stream); sf_out, useUE8M0, multiProcessorCount, stream);
break; });
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
useUE8M0, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type()
<< " for the input datatype which is invalid";
throw std::runtime_error(
"Unsupported input data type for quantize_to_fp4.");
}
}
} }
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
namespace vllm {
// Convert PyTorch cpp type to CUDA type
template <typename T>
struct CUDATypeConverter {
using Type = T;
};
template <>
struct CUDATypeConverter<at::Half> {
using Type = half;
};
template <>
struct CUDATypeConverter<at::BFloat16> {
using Type = __nv_bfloat16;
};
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
return nullptr;
}
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
}
} // namespace vllm
...@@ -417,7 +417,7 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): ...@@ -417,7 +417,7 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
)) ))
def prepacked_type_key(prepack_type: PrepackTypeConfig): def prepacked_type_key(prepack_type: PrepackTypeConfig):
# For now we we can just use the first accumulator type seen since # For now, we can just use the first accumulator type seen since
# the tensor core shapes/layouts don't vary based on accumulator # the tensor core shapes/layouts don't vary based on accumulator
# type so we can generate less code this way # type so we can generate less code this way
return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
......
...@@ -115,8 +115,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -115,8 +115,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); // "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
// ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); // ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #ifndef USE_ROCM
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
ops.def( ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()"); "Tensor input, Tensor input_global_scale) -> ()");
...@@ -169,6 +168,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -169,6 +168,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"); "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Polynomial Normalization.
ops.def(
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
"epsilon) -> ()");
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
// Apply repetition penalties to logits in-place // Apply repetition penalties to logits in-place
ops.def( ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
...@@ -521,10 +526,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -521,10 +526,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// SM100 CUTLASS MLA decode // SM100 CUTLASS MLA decode
ops.def( ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens," " Tensor q_pe, Tensor kv_c_and_k_pe_cache,"
" Tensor page_table, Tensor workspace, float " " Tensor seq_lens, Tensor page_table,"
"scale," " Tensor workspace, float scale,"
" int num_kv_splits) -> ()"); " int num_kv_splits) -> ()");
// conditionally compiled so impl in source file // conditionally compiled so impl in source file
...@@ -698,16 +703,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -698,16 +703,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()"); " Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
cache_ops.def(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
&cp_fused_concat_and_cache_mla);
// Convert the key and value cache to fp8 data type. // Convert the key and value cache to fp8 data type.
cache_ops.def( cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
......
...@@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ ...@@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Check the size of the wheel if RUN_WHEEL_CHECK is true # Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY .buildkite/check-wheel-size.py check-wheel-size.py COPY .buildkite/check-wheel-size.py check-wheel-size.py
# sync the default value with .buildkite/check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py
ARG VLLM_MAX_SIZE_MB=400 ARG VLLM_MAX_SIZE_MB=450
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG RUN_WHEEL_CHECK=true ARG RUN_WHEEL_CHECK=true
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
...@@ -261,6 +261,8 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" ...@@ -261,6 +261,8 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts # Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy ENV UV_LINK_MODE=copy
# Install libnuma-dev, required by fastsafetensors (fixes #20384)
RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*
COPY requirements/lint.txt requirements/lint.txt COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt COPY requirements/dev.txt requirements/dev.txt
...@@ -373,7 +375,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist ...@@ -373,7 +375,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# Install FlashInfer from source # Install FlashInfer from source
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with "flashinfer" extra in setup.py # Keep this in sync with "flashinfer" extra in setup.py
ARG FLASHINFER_GIT_REF="v0.2.14.post1" ARG FLASHINFER_GIT_REF="v0.3.0"
# Flag to control whether to compile FlashInfer AOT kernels # Flag to control whether to compile FlashInfer AOT kernels
# Set to "true" to enable AOT compilation: # Set to "true" to enable AOT compilation:
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... # docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
...@@ -432,11 +434,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -432,11 +434,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# Install DeepGEMM from source # Install DeepGEMM from source
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" ARG DEEPGEMM_GIT_REF
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \ VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"}
&& rm /tmp/install_deepgemm.sh
# Install EP kernels(pplx-kernels and DeepEP), NixL # Install EP kernels(pplx-kernels and DeepEP), NixL
COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
...@@ -518,7 +519,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -518,7 +519,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
else \ else \
BITSANDBYTES_VERSION="0.46.1"; \ BITSANDBYTES_VERSION="0.46.1"; \
fi; \ fi; \
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3] uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3]
ENV VLLM_USAGE_SOURCE production-docker-image ENV VLLM_USAGE_SOURCE production-docker-image
......
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04"
FROM $BASE_IMAGE
RUN echo "Base image is $BASE_IMAGE"
# Install some basic utilities
RUN apt-get update && \
apt-get install -y \
git \
python3 \
python3-pip \
ffmpeg libsm6 libxext6 libgl1
### Mount Point ###
# When launching the container, mount the code directory to /workspace
ARG APP_MOUNT=/workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}/vllm
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install pytest
# uninstall transformers-neuronx package explicitly to avoid version conflict
RUN python3 -m pip uninstall -y transformers-neuronx
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install -U \
'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements/neuron.txt
ENV VLLM_TARGET_DEVICE neuron
RUN --mount=type=bind,source=.git,target=.git \
pip install --no-build-isolation -v -e .
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
# install transformers-neuronx package as an optional dependencies (for V0)
# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict
RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps
RUN python3 -m pip install sentencepiece transformers==4.48.0 -U
# overwrite entrypoint to run bash script
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py
CMD ["/bin/bash"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment