Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
807 additions
and
972 deletions
+807
-972
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+0
-9
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+251
-0
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+47
-25
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+7
-6
csrc/ops.h
csrc/ops.h
+5
-3
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+10
-4
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+0
-3
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
+0
-3
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+90
-112
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+1
-27
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
...quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
+28
-184
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+16
-0
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+47
-263
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+18
-0
csrc/quantization/fp4/nvfp4_quant_kernels.cu
csrc/quantization/fp4/nvfp4_quant_kernels.cu
+17
-254
csrc/quantization/fp4/nvfp4_utils.cuh
csrc/quantization/fp4/nvfp4_utils.cuh
+251
-0
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+1
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+11
-16
docker/Dockerfile
docker/Dockerfile
+7
-6
docker/Dockerfile.neuron
docker/Dockerfile.neuron
+0
-56
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
csrc/dispatch_utils.h
View file @
38d80967
...
...
@@ -52,15 +52,6 @@
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \
AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
...
...
csrc/layernorm_kernels.cu
View file @
38d80967
...
...
@@ -140,6 +140,211 @@ fused_add_rms_norm_kernel(
}
}
/* FP16/BF16 vector helper for polynomial normalization (poly norm).

   _f16VecPN derives from _f16Vec and reuses its storage (`data`) and its
   Converter / T1 / T2 type aliases. It adds the two operations that poly norm
   needs and that _f16Vec does not offer:
     - sum_pows():          per-vector sums of x^2, x^4 and x^6
     - poly_norm_inplace(): in-place evaluation of the normalized polynomial

   Elements are processed two at a time: each pair is packed into a T2,
   converted to a float2, computed on in fp32 and (for the in-place update)
   converted back. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  /* Returns a tuple (sum of x^2, sum of x^4, sum of x^6) taken over the
     `width` elements held in `data`, accumulated in fp32. */
  __device__ auto sum_pows() const {
    float acc_pow2 = 0.0f;
    float acc_pow4 = 0.0f;
    float acc_pow6 = 0.0f;

#pragma unroll
    for (int lane = 0; lane < width; lane += 2) {
      const float2 pair = Converter::convert(T2{data[lane], data[lane + 1]});

      // Even powers of the first element of the pair.
      const float a2 = pair.x * pair.x;
      const float a4 = a2 * a2;
      const float a6 = a4 * a2;

      // Even powers of the second element of the pair.
      const float b2 = pair.y * pair.y;
      const float b4 = b2 * b2;
      const float b6 = b4 * b2;

      acc_pow2 += a2 + b2;
      acc_pow4 += a4 + b4;
      acc_pow6 += a6 + b6;
    }

    return std::make_tuple(acc_pow2, acc_pow4, acc_pow6);
  }

  /* Overwrites every element x of `data` with
       w2_inv_std * x + w1_inv_std2 * x^2 + w0_inv_std3 * x^3 + bias
     evaluated in fp32 and converted back to the storage type. */
  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3,
                                    const float bias) {
#pragma unroll
    for (int lane = 0; lane < width; lane += 2) {
      const float2 pair = Converter::convert(T2{data[lane], data[lane + 1]});

      const float a2 = pair.x * pair.x;
      const float a3 = a2 * pair.x;
      const float b2 = pair.y * pair.y;
      const float b3 = b2 * pair.y;

      float2 result;
      result.x = w2_inv_std * pair.x + w1_inv_std2 * a2 + w0_inv_std3 * a3 +
                 bias;
      result.y = w2_inv_std * pair.y + w1_inv_std2 * b2 + w0_inv_std3 * b3 +
                 bias;

      const auto packed = Converter::convert(result);
      data[lane] = packed.x;
      data[lane + 1] = packed.y;
    }
  }
};
/* Vectorized poly norm kernel, selected when width > 0 and
   _typeConvert<scalar_t>::exists.

   Each block handles the row selected by blockIdx.x. For every element x of
   that row it writes
     w2 * x   * rsqrt(mean(x^2) + epsilon)
   + w1 * x^2 * rsqrt(mean(x^4) + epsilon)
   + w0 * x^3 * rsqrt(mean(x^6) + epsilon)
   + bias
   where w0..w2 are weight[0..2], bias is bias[0] and the means are taken over
   the hidden_size elements of the row.

   The row is accessed as hidden_size / width vectors of `width` elements, so
   the caller must make sure hidden_size is a multiple of width and that
   `input` / `out` are suitably aligned for _f16VecPN<scalar_t, width>.

   The block-wide reduction is instantiated for 1024 threads, so blockDim.x
   must not exceed 1024. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);

  /* The vector views and the argument pointers are all declared `restrict`
     as they are not aliased in practice. `input` and `out` must only be
     accessed through their vector views in this kernel. */
  auto* __restrict__ input_v =
      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
  const int vec_hidden_size = hidden_size / width;

  // Offset of this block's row, in vectors. Computed once in 64-bit so that
  // the index cannot wrap for inputs with more than 2^31 vectors.
  const int64_t row_offset =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;

  // Pass 1: per-thread partial sums of x^2, x^4 and x^6.
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  // Combine the three partial sums across the block in a single reduction.
  float3 thread_variances = make_float3(variance, variance2, variance3);

  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  // Thread 0 folds the weights into the inverse RMS factors and publishes
  // them (and the bias) to the rest of the block through shared memory.
  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: re-read the row, apply the polynomial and store the result.
  auto* __restrict__ out_v =
      reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[id] = temp;
  }
}
/* Generic poly_norm_kernel, selected when width == 0 or when
   _typeConvert<scalar_t>::exists is false.
   The width field is not used here but necessary for other specializations.

   Each block handles the row selected by blockIdx.x. For every element x of
   that row it writes
     w2 * x   * rsqrt(mean(x^2) + epsilon)
   + w1 * x^2 * rsqrt(mean(x^4) + epsilon)
   + w0 * x^3 * rsqrt(mean(x^6) + epsilon)
   + bias
   where w0..w2 are weight[0..2], bias is bias[0] and the means are taken over
   the hidden_size elements of the row. All arithmetic is done in fp32.

   The block-wide reduction is instantiated for 1024 threads, so blockDim.x
   must not exceed 1024. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Offset of this block's row, in elements. Computed once in 64-bit so that
  // the index cannot wrap for inputs with more than 2^32 elements.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: per-thread partial sums of x^2, x^4 and x^6.
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  // Combine the three partial sums across the block in a single reduction.
  float3 thread_variances = make_float3(variance, variance2, variance3);

  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  // Thread 0 folds the weights into the inverse RMS factors and publishes
  // them (and the bias) to the rest of the block through shared memory.
  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: re-read the row, apply the polynomial and store the result.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    float x2 = x * x;
    float x3 = x2 * x;

    out[row_offset + idx] =
        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
                   s_bias);
  }
}
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
...
...
@@ -219,3 +424,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
/* Launches vllm::poly_norm_kernel<scalar_t, width>, with scalar_t chosen by
   VLLM_DISPATCH_FLOATING_TYPES from input.scalar_type().
   The expansion refers to the following names, which must be in scope at the
   point of use: out, input, weight, bias, epsilon, hidden_size, grid, block
   and stream.
   The definition already ends in ';', so `LAUNCH_FUSED_POLY_NORM(8);` expands
   to the launch followed by an empty statement. */
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
/* Host entry point for polynomial normalization.

   Launches one block per row of `input` (a row being the last dimension,
   hidden_size elements) and writes the normalized result to `out`.

   out     [..., hidden_size]  contiguous, must not share storage start with
                               `input`, same number of elements as `input`
   input   [..., hidden_size]  contiguous
   weight  [3]                 contiguous; the kernels read weight[0..2]
   bias    [1]                 the kernels read bias[0]
   epsilon                     added to each mean before the reciprocal sqrt

   Fails a TORCH_CHECK when any of the above requirements is violated or when
   hidden_size is 0. Returns without launching anything when `input` has no
   rows. */
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.data_ptr() != input.data_ptr());

  // The kernels index weight[0..2] and bias[0] through raw pointers, so make
  // sure those reads stay in bounds and hit the intended elements.
  TORCH_CHECK(weight.is_contiguous() && weight.numel() >= 3,
              "poly_norm: weight must be contiguous with at least 3 elements");
  TORCH_CHECK(bias.numel() >= 1,
              "poly_norm: bias must have at least 1 element");
  // The kernels write one output element per input element.
  TORCH_CHECK(out.numel() == input.numel(),
              "poly_norm: out and input must have the same number of elements");

  int hidden_size = input.size(-1);
  // Guards the division below and the block size computed from hidden_size.
  TORCH_CHECK(hidden_size > 0, "poly_norm: hidden_size must be positive");
  int num_tokens = input.numel() / hidden_size;
  if (num_tokens == 0) {
    // Nothing to normalize; a zero-sized grid is not a valid launch.
    return;
  }

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    // width == 0 selects the generic element-wise kernel.
    LAUNCH_FUSED_POLY_NORM(0);
  }
}
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
38d80967
...
...
@@ -27,11 +27,12 @@
template
<
int
kNThreads_
,
int
kNItems_
,
int
kNRows_
,
bool
kIsEvenLen_
,
bool
kIsVariableB_
,
bool
kIsVariableC_
,
bool
kHasZ_
,
bool
kVarlen_
,
typename
input_t_
,
typename
weight_t_
>
bool
kHasZ_
,
bool
kVarlen_
,
typename
input_t_
,
typename
weight_t_
,
typename
state_t_
>
struct
Selective_Scan_fwd_kernel_traits
{
static_assert
(
kNItems_
%
4
==
0
);
using
input_t
=
input_t_
;
using
weight_t
=
weight_t_
;
using
state_t
=
state_t_
;
static
constexpr
int
kNThreads
=
kNThreads_
;
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
static
constexpr
int
kMinBlocks
=
kNThreads
<
128
?
5
:
3
;
...
...
@@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t
*
Bvar
=
reinterpret_cast
<
input_t
*>
(
params
.
B_ptr
)
+
sequence_start_index
*
params
.
B_batch_stride
+
group_id
*
params
.
B_group_stride
;
weight_t
*
C
=
reinterpret_cast
<
weight_t
*>
(
params
.
C_ptr
)
+
dim_id
*
kNRows
*
params
.
C_d_stride
;
input_t
*
Cvar
=
reinterpret_cast
<
input_t
*>
(
params
.
C_ptr
)
+
sequence_start_index
*
params
.
C_batch_stride
+
group_id
*
params
.
C_group_stride
;
input
_t
*
ssm_states
=
reinterpret_cast
<
input
_t
*>
(
params
.
ssm_states_ptr
)
+
typename
Ktraits
::
state
_t
*
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state
_t
*>
(
params
.
ssm_states_ptr
)
+
cache_index
*
params
.
ssm_states_batch_stride
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
...
...
@@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if
(
threadIdx
.
x
==
0
)
{
smem_running_prefix
[
state_idx
]
=
prefix_op
.
running_prefix
;
if
(
chunk
==
n_chunks
-
1
)
{
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
]
=
input
_t
(
prefix_op
.
running_prefix
.
y
);
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
]
=
typename
Ktraits
::
state
_t
(
prefix_op
.
running_prefix
.
y
);
}
}
#pragma unroll
...
...
@@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
template
<
int
kNThreads
,
int
kNItems
,
typename
input_t
,
typename
weight_t
>
template
<
int
kNThreads
,
int
kNItems
,
typename
input_t
,
typename
weight_t
,
typename
state_t
>
void
selective_scan_fwd_launch
(
SSMParamsBase
&
params
,
cudaStream_t
stream
)
{
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row.
...
...
@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
BOOL_SWITCH
(
params
.
seqlen
%
(
kNThreads
*
kNItems
)
==
0
,
kIsEvenLen
,
[
&
]
{
BOOL_SWITCH
(
params
.
z_ptr
!=
nullptr
,
kHasZ
,
[
&
]
{
BOOL_SWITCH
(
params
.
query_start_loc_ptr
!=
nullptr
,
kVarlen
,
[
&
]
{
using
Ktraits
=
Selective_Scan_fwd_kernel_traits
<
kNThreads
,
kNItems
,
kNRows
,
kIsEvenLen
,
kIsVariableB
,
kIsVariableC
,
kHasZ
,
kVarlen
,
input_t
,
weight_t
>
;
using
Ktraits
=
Selective_Scan_fwd_kernel_traits
<
kNThreads
,
kNItems
,
kNRows
,
kIsEvenLen
,
kIsVariableB
,
kIsVariableC
,
kHasZ
,
kVarlen
,
input_t
,
weight_t
,
state_t
>
;
constexpr
int
kSmemSize
=
Ktraits
::
kSmemSize
+
kNRows
*
MAX_DSTATE
*
sizeof
(
typename
Ktraits
::
scan_t
);
dim3
grid
(
params
.
batch
,
params
.
dim
/
kNRows
);
auto
kernel
=
&
selective_scan_fwd_kernel
<
Ktraits
>
;
...
...
@@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
});
}
template
<
typename
input_t
,
typename
weight_t
>
template
<
typename
input_t
,
typename
weight_t
,
typename
state_t
>
void
selective_scan_fwd_cuda
(
SSMParamsBase
&
params
,
cudaStream_t
stream
)
{
#ifndef USE_ROCM
if
(
params
.
seqlen
<=
128
)
{
selective_scan_fwd_launch
<
32
,
4
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
32
,
4
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
32
,
8
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
32
,
8
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
512
)
{
selective_scan_fwd_launch
<
32
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
32
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
{
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
#else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
64
,
4
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
64
,
4
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
512
)
{
selective_scan_fwd_launch
<
64
,
8
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
64
,
8
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
{
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
#endif
}
template
void
selective_scan_fwd_cuda
<
at
::
BFloat16
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
at
::
Half
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
float
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
at
::
BFloat16
,
float
,
at
::
BFloat16
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
at
::
BFloat16
,
float
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
at
::
Half
,
float
,
at
::
Half
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
at
::
Half
,
float
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
float
,
float
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)
\
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE,
STYPE,
NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
using weight_t = float; \
__VA_ARGS__(); \
if (STYPE == at::ScalarType::Half) { \
using state_t = at::Half; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
} \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
using weight_t = float; \
__VA_ARGS__(); \
if (STYPE == at::ScalarType::BFloat16) { \
using state_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
} \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
using weight_t = float; \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}
template
<
typename
input_t
,
typename
weight_t
>
template
<
typename
input_t
,
typename
weight_t
,
typename
state_t
>
void
selective_scan_fwd_cuda
(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
void
set_ssm_params_fwd
(
SSMParamsBase
&
params
,
...
...
@@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at
::
Tensor
out
=
delta
;
TORCH_CHECK
(
ssm_states
.
scalar_type
()
==
input_type
);
// ssm_states can now be either the same as input_type or float32
auto
state_type
=
ssm_states
.
scalar_type
();
TORCH_CHECK
(
state_type
==
input_type
||
state_type
==
at
::
ScalarType
::
Float
);
TORCH_CHECK
(
ssm_states
.
is_cuda
());
TORCH_CHECK
(
ssm_states
.
stride
(
-
1
)
==
1
);
...
...
@@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
u
));
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
u
.
scalar_type
(),
"selective_scan_fwd"
,
[
&
]
{
selective_scan_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
u
.
scalar_type
(),
ssm_states
.
scalar_type
(),
"selective_scan_fwd"
,
[
&
]
{
selective_scan_fwd_cuda
<
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
});
}
csrc/moe/grouped_topk_kernels.cu
View file @
38d80967
...
...
@@ -28,6 +28,7 @@ namespace cg = cooperative_groups;
namespace
vllm
{
namespace
moe
{
constexpr
float
kNegInfinity
=
INFINITY
*
-
1
;
constexpr
unsigned
FULL_WARP_MASK
=
0xffffffff
;
constexpr
int32_t
WARP_SIZE
=
32
;
constexpr
int32_t
BLOCK_SIZE
=
512
;
...
...
@@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel(
warp_id
*
topk
;
s_topk_idx
+=
warp_id
*
topk
;
T
value
=
cuda
::
std
::
numeric_limits
<
T
>::
min
()
;
T
topk_group_value
=
cuda
::
std
::
numeric_limits
<
T
>::
min
()
;
T
value
=
kNegInfinity
;
T
topk_group_value
=
kNegInfinity
;
int32_t
num_equalto_topkth_group
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
...
...
@@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel(
__syncwarp
();
// Ensure all threads have valid data before reduction
topk_group_value
=
cg
::
reduce
(
tile
,
value
,
cg
::
greater
<
T
>
());
if
(
value
==
topk_group_value
)
{
value
=
cuda
::
std
::
numeric_limits
<
T
>::
min
()
;
value
=
kNegInfinity
;
}
pre_count_equal_to_top_value
=
count_equal_to_top_value
;
count_equal_to_top_value
=
__popc
(
__ballot_sync
(
FULL_WARP_MASK
,
(
value
==
cuda
::
std
::
numeric_limits
<
T
>::
min
(
))));
FULL_WARP_MASK
,
(
value
==
cuda
_cast
<
T
,
float
>
(
kNegInfinity
))));
}
num_equalto_topkth_group
=
target_num_min
-
pre_count_equal_to_top_value
;
}
...
...
@@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int
count_equalto_topkth_group
=
0
;
bool
if_proceed_next_topk
=
(
topk_group_value
!=
cuda
::
std
::
numeric_limits
<
T
>::
min
(
));
(
topk_group_value
!=
cuda
_cast
<
T
,
float
>
(
kNegInfinity
));
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
for
(
int
i_group
=
0
;
i_group
<
n_group
;
i_group
++
)
{
if
((
group_scores
[
i_group
]
>
topk_group_value
)
||
...
...
@@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel(
(
i
<
num_experts_per_group
)
&&
isfinite
(
cuda_cast
<
float
,
T
>
(
scores_with_bias
[
offset
+
i
]))
?
scores_with_bias
[
offset
+
i
]
:
cuda
::
std
::
numeric_limits
<
T
>::
min
(
);
:
cuda
_cast
<
T
,
float
>
(
kNegInfinity
);
queue
.
add
(
candidates
,
offset
+
i
);
}
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
...
...
csrc/ops.h
View file @
38d80967
...
...
@@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
poly_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
bias
,
double
epsilon
);
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
...
...
@@ -130,8 +133,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
#ifndef USE_ROCM
void
silu_and_mul_nvfp4_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
output_block_scale
,
torch
::
Tensor
&
input
,
...
...
@@ -356,4 +358,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void
qr_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
quant_level
,
bool
cast_bf2half
=
false
);
int64_t
qr_max_size
();
#endif
\ No newline at end of file
#endif
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
View file @
38d80967
...
...
@@ -11,6 +11,7 @@
#include "core/registration.h"
#include "cutlass/cutlass.h"
#include <limits>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
...
...
@@ -169,6 +170,11 @@ struct W4A8GemmKernel {
int
k
=
A
.
size
(
1
);
int
n
=
B
.
size
(
1
);
// safely cast group_size to int
TORCH_CHECK
(
group_size
>
0
&&
group_size
<=
std
::
numeric_limits
<
int
>::
max
(),
"group_size out of supported range for int: "
,
group_size
);
int
const
group_size_int
=
static_cast
<
int
>
(
group_size
);
// Allocate output
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
auto
device
=
A
.
device
();
...
...
@@ -181,7 +187,7 @@ struct W4A8GemmKernel {
auto
A_ptr
=
static_cast
<
MmaType
const
*>
(
A
.
const_data_ptr
());
auto
B_ptr
=
static_cast
<
QuantType
const
*>
(
B
.
const_data_ptr
());
auto
D_ptr
=
static_cast
<
ElementD
*>
(
D
.
data_ptr
());
// can we avoid harcode the 8 here
// can we avoid har
d
code the 8 here
auto
S_ptr
=
static_cast
<
cutlass
::
Array
<
ElementScale
,
ScalePackSize
>
const
*>
(
group_scales
.
const_data_ptr
());
...
...
@@ -192,7 +198,7 @@ struct W4A8GemmKernel {
cute
::
tile_to_shape
(
LayoutAtomQuant
{},
shape_B
);
// strides
int
const
scale_k
=
cutlass
::
ceil_div
(
k
,
group_size
);
int
const
scale_k
=
cutlass
::
ceil_div
(
k
,
group_size
_int
);
StrideA
stride_A
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
// Reverse stride here due to swap and transpose
...
...
@@ -211,8 +217,8 @@ struct W4A8GemmKernel {
using
EpilogueArguments
=
typename
GemmKernelShuffled
::
EpilogueArguments
;
MainloopArguments
mainloop_arguments
{
B_ptr
,
layout_B_reordered
,
A_ptr
,
stride_A
,
S_ptr
,
stride_S
,
group_size
};
B_ptr
,
layout_B_reordered
,
A_ptr
,
stride_A
,
S_ptr
,
stride_S
,
group_size
_int
};
EpilogueArguments
epilogue_arguments
{
ChTokScalesEpilogue
::
prepare_args
(
channel_scales
,
token_scales
),
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
View file @
38d80967
...
...
@@ -14,9 +14,6 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
View file @
38d80967
...
...
@@ -14,9 +14,6 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
View file @
38d80967
...
...
@@ -13,27 +13,18 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
template
<
typename
SchedulerType
,
typename
OutType
,
int
GroupSizeM_
,
int
GroupSizeN_
,
int
GroupSizeK_
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
>
// clang-format off
template
<
class
OutType
,
int
ScaleGranularityM
,
int
ScaleGranularityN
,
int
ScaleGranularityK
,
class
MmaTileShape
,
class
ClusterShape
,
class
EpilogueScheduler
,
class
MainloopScheduler
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
GroupSizeM
=
Int
<
GroupSizeM_
>
;
using
GroupSizeN
=
Int
<
GroupSizeN_
>
;
using
GroupSizeK
=
Int
<
GroupSizeK_
>
;
using
TileSizeM
=
Int
<
TileSizeM_
>
;
static_assert
(
TileSizeM_
%
GroupSizeM_
==
0
,
"TileSizeM must be a multiple of GroupSizeM"
);
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
...
...
@@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise {
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementD
=
OutType
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
using
StrideC
=
Stride
D
;
using
ElementC
=
void
;
// TODO: support bias
using
LayoutC
=
Layout
D
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementBlockScale
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm90BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
>
;
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
TileShape
=
Shape
<
TileSizeM
,
GroupSizeN
,
GroupSizeK
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
GroupSizeM_
>
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
EpilogueTileType
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
StrideC
,
AlignmentC
,
ElementD
,
StrideD
,
AlignmentD
,
EpilogueSchedule
,
StoreEpilogueCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
,
AlignmentA
,
ElementB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
ElementScalar
=
float
;
using
DefaultOperation
=
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
ElementCompute
,
ElementC
,
ElementScalar
,
RoundStyle
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
MmaTileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueScheduler
,
DefaultOperation
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
;
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
SchedulerType
>>
;
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
struct
GemmKernel
:
public
KernelType
{};
using
StrideA
=
typename
GemmKernel
::
StrideA
;
using
StrideB
=
typename
GemmKernel
::
StrideB
;
};
template
<
typename
Gemm
>
...
...
@@ -99,76 +105,54 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
LayoutSFA
=
typename
Gemm
::
LayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
LayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
ScaleConfig
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
auto
prob_shape
=
c3x
::
get_problem_shape
(
a
,
b
);
int32_t
m
=
get
<
0
>
(
prob_shape
),
n
=
get
<
1
>
(
prob_shape
),
k
=
get
<
2
>
(
prob_shape
);
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
TORCH_CHECK
(
m
%
4
==
0
,
"m must be divisible by 4"
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
;
StrideB
b_stride
;
StrideC
c_stride
;
a_stride
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
LayoutSFA
layout_SFA
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
// being 1 (i.e. a row or column vector)
auto
is_contiguous_vector
=
[](
const
torch
::
Tensor
&
t
)
{
auto
t_sizes
=
t
.
sizes
();
return
t
.
is_contiguous
()
&&
(
t
.
dim
()
==
1
||
(
t
.
dim
()
==
2
&&
*
std
::
min_element
(
t_sizes
.
begin
(),
t_sizes
.
end
())
==
1
));
};
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
// we don't have to deal with enforcing implicit layouts
TORCH_CHECK
(
a_scales
.
size
(
0
)
==
m
/
Gemm
::
GroupSizeM
::
value
);
TORCH_CHECK
(
a_scales
.
size
(
1
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
a_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
a_scales
),
"a_scales must be M major"
);
TORCH_CHECK
(
b_scales
.
size
(
0
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
n
/
Gemm
::
GroupSizeN
::
value
);
TORCH_CHECK
(
b_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
b_scales
),
"b_scales must be K major"
);
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
b_scales_ptr
};
auto
mainloop_args
=
[
&
](){
return
typename
GemmKernel
::
MainloopArguments
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
}();
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
TileSchedulerArguments
scheduler
;
static
constexpr
bool
UsesStreamKScheduler
=
cute
::
is_same_v
<
typename
GemmKernel
::
TileSchedulerTag
,
cutlass
::
gemm
::
StreamKScheduler
>
;
if
constexpr
(
UsesStreamKScheduler
)
{
using
DecompositionMode
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm90StreamKParams
::
DecompositionMode
;
using
ReductionMode
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm90StreamKParams
::
ReductionMode
;
scheduler
.
decomposition_mode
=
DecompositionMode
::
StreamK
;
scheduler
.
reduction_mode
=
ReductionMode
::
Nondeterministic
;
}
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
,
scheduler
);
epilogue_args
);
}
template
<
typename
OutType
>
...
...
@@ -177,18 +161,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
k
=
a
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
if
(
k
>
3
*
n
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
cutlass
::
gemm
::
StreamKScheduler
,
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
cutlass
::
gemm
::
PersistentScheduler
,
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
// TODO: better heuristics
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
128
,
128
,
Shape
<
_128
,
_128
,
_128
>
,
Shape
<
_1
,
_2
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
View file @
38d80967
...
...
@@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK
(
a_scales
.
dim
()
==
2
,
"a scale must be 2d tensor."
);
TORCH_CHECK
(
b_scales
.
dim
()
==
2
,
"b scale must be 2d tensor."
);
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
10
0
)
{
if
(
version_num
>=
9
0
)
{
TORCH_CHECK
(
a
.
size
(
0
)
==
a_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
a
.
size
(
1
),
int64_t
(
128
))
==
a_scales
.
size
(
1
),
...
...
@@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
cuda_utils
::
ceil_div
(
b
.
size
(
0
),
int64_t
(
128
))
==
b_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
b
.
size
(
1
),
int64_t
(
128
))
==
b_scales
.
size
(
1
),
"b_scale_group_shape must be [128, 128]."
);
}
else
{
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
}
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
...
...
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
View file @
38d80967
...
...
@@ -26,164 +26,17 @@
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
namespace
vllm
{
// Get type2 from type or vice versa (applied to half and bfloat16)
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
c10
::
Half
;
};
template
<
>
struct
TypeConverter
<
c10
::
Half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
c10
::
BFloat16
;
};
template
<
>
struct
TypeConverter
<
c10
::
BFloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
template
<
class
Type
>
__inline__
__device__
PackedVec
<
Type
>
compute_silu
(
PackedVec
<
Type
>&
vec
,
PackedVec
<
Type
>&
vec2
)
{
PackedVec
<
Type
>
result
;
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
++
i
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
c10
::
H
alf
>
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
h
alf
>
)
{
half2
val
(
0.5
f
,
0.5
f
);
half2
t0
=
__hmul2
(
vec
.
elts
[
i
],
val
);
half2
t1
=
__hfma2
(
h2tanh
(
t0
),
val
,
val
);
...
...
@@ -206,13 +59,12 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
PackedVec
<
Type
>&
vec2
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
PackedVec
<
Type
>
out_silu
=
compute_silu
(
vec
,
vec2
);
// Get absolute maximum values among the local 8 values.
auto
localMax
=
__habs2
(
out_silu
.
elts
[
0
]);
// Local maximum value.
#pragma unroll
// Local maximum value.
#pragma unroll
for
(
int
i
=
1
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
localMax
=
__hmax2
(
localMax
,
__habs2
(
out_silu
.
elts
[
i
]));
}
...
...
@@ -259,9 +111,9 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
// Convert the input to float.
float2
fp2Vals
[
CVT_FP4_ELTS_PER_THREAD
/
2
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
c10
::
H
alf
>
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
h
alf
>
)
{
fp2Vals
[
i
]
=
__half22float2
(
out_silu
.
elts
[
i
]);
}
else
{
fp2Vals
[
i
]
=
__bfloat1622float2
(
out_silu
.
elts
[
i
]);
...
...
@@ -275,22 +127,14 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
// Write the e2m1 values to global memory.
return
e2m1Vec
;
#else
return
0
;
#endif
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
1024
,
4
)
silu_and_cvt_fp16_to_fp4
(
#else
silu_and_cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__global__
void
__launch_bounds__
(
1024
,
4
)
silu_and_cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
...
...
@@ -328,22 +172,25 @@ silu_and_cvt_fp16_to_fp4(
in_vec
,
in_vec2
,
SFScaleVal
,
sf_out
);
}
}
#endif
}
}
// namespace vllm
void
silu_and_mul_nvfp4_quant
(
torch
::
Tensor
&
output
,
// [..., d]
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
&
input
,
// [..., 2 * d]
torch
::
Tensor
&
input_sf
)
{
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kFloat16
||
input
.
dtype
()
==
torch
::
kBFloat16
);
void
silu_and_mul_nvfp4_quant_sm1xxa
(
torch
::
Tensor
&
output
,
// [..., d]
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
&
input
,
// [..., 2 * d]
torch
::
Tensor
&
input_sf
)
{
int32_t
m
=
input
.
size
(
0
);
int32_t
n
=
input
.
size
(
1
)
/
2
;
TORCH_CHECK
(
n
%
16
==
0
,
"The N dimension must be multiple of 16."
);
TORCH_CHECK
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
||
input
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
,
"Unsupported input data type for quantize_to_fp4."
);
int
multiProcessorCount
=
get_device_attribute
(
cudaDevAttrMultiProcessorCount
,
-
1
);
auto
input_sf_ptr
=
static_cast
<
float
const
*>
(
input_sf
.
data_ptr
());
auto
sf_out
=
static_cast
<
int32_t
*>
(
output_sf
.
data_ptr
());
auto
output_ptr
=
static_cast
<
int64_t
*>
(
output
.
data_ptr
());
...
...
@@ -352,17 +199,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d]
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
1024
));
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m
),
multiProcessorCount
*
numBlocksPerSM
));
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"act_and_mul_quant_kernel"
,
[
&
]
{
auto
input_ptr
=
reinterpret_cast
<
scalar_t
const
*>
(
input
.
data_ptr
());
VLLM_DISPATCH_BYTE_TYPES
(
output
.
scalar_type
(),
"fused_act_and_mul_quant_kernel_nvfp4_type"
,
[
&
]
{
vllm
::
silu_and_cvt_fp16_to_fp4
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
input
.
scalar_type
(),
"silu_and_mul_nvfp4_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
vllm
::
silu_and_cvt_fp16_to_fp4
<
cuda_type
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
}
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
View file @
38d80967
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <cutlass/arch/arch.h>
...
...
csrc/quantization/fp4/nvfp4_experts_quant.cu
View file @
38d80967
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
#include "nvfp4_utils.cuh"
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__device__
uint32_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto
localMax
=
__habs2
(
vec
.
elts
[
0
]);
// Local maximum value.
#pragma unroll
for
(
int
i
=
1
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
localMax
=
__hmax2
(
localMax
,
__habs2
(
vec
.
elts
[
i
]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax
=
__hmax2
(
__shfl_xor_sync
(
uint32_t
(
-
1
),
localMax
,
1
),
localMax
);
// Get the final absolute maximum values.
float
vecMax
=
float
(
__hmax
(
localMax
.
x
,
localMax
.
y
));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float
SFValue
=
SFScaleVal
*
(
vecMax
*
reciprocal_approximate_ftz
(
6.0
f
));
// 8 bits representation of the SF.
uint8_t
fp8SFVal
;
// Write the SF to global memory (STG.8).
if
constexpr
(
UE8M0_SF
)
{
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t
tmp
=
reinterpret_cast
<
uint32_t
&>
(
SFValue
)
>>
23
;
fp8SFVal
=
tmp
&
0xff
;
// Convert back to fp32.
reinterpret_cast
<
uint32_t
&>
(
SFValue
)
=
tmp
<<
23
;
}
else
{
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3
tmp
=
__nv_fp8_e4m3
(
SFValue
);
reinterpret_cast
<
__nv_fp8_e4m3
&>
(
fp8SFVal
)
=
tmp
;
// Convert back to fp32.
SFValue
=
float
(
tmp
);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float
outputScale
=
SFValue
!=
0
?
reciprocal_approximate_ftz
(
SFValue
*
reciprocal_approximate_ftz
(
SFScaleVal
))
:
0.0
f
;
if
(
SFout
)
{
// Write the SF to global memory (STG.8).
*
SFout
=
fp8SFVal
;
}
// Convert the input to float.
float2
fp2Vals
[
CVT_FP4_ELTS_PER_THREAD
/
2
];
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
half
>
)
{
fp2Vals
[
i
]
=
__half22float2
(
vec
.
elts
[
i
]);
}
else
{
fp2Vals
[
i
]
=
__bfloat1622float2
(
vec
.
elts
[
i
]);
}
fp2Vals
[
i
].
x
*=
outputScale
;
fp2Vals
[
i
].
y
*=
outputScale
;
}
// Convert to e2m1 values.
uint32_t
e2m1Vec
=
fp32_vec_to_e2m1
(
fp2Vals
);
// Write the e2m1 values to global memory.
return
e2m1Vec
;
#else
return
0
;
#endif
}
namespace
vllm
{
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
,
bool
SMALL_NUM_EXPERTS
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
,
bool
low_latency
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__global__
void
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
,
bool
low_latency
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
...
...
@@ -299,8 +94,8 @@ cvt_fp16_to_fp4(
&
input_offset_by_experts
[
chunk_start
+
12
]));
local_offsets
[
16
]
=
__ldca
(
&
input_offset_by_experts
[
chunk_start
+
16
]);
// Check against the 16 loaded offsets
#pragma unroll
// Check against the 16 loaded offsets
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
if
(
rowIdx
>=
local_offsets
[
i
]
&&
rowIdx
<
local_offsets
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
local_offsets
[
i
];
...
...
@@ -330,21 +125,15 @@ cvt_fp16_to_fp4(
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
#endif
}
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template
<
class
Type
,
bool
UE8M0_SF
=
false
,
bool
SMALL_NUM_EXPERTS
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
1024
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__global__
void
__launch_bounds__
(
1024
,
4
)
cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
...
...
@@ -425,7 +214,6 @@ cvt_fp16_to_fp4(
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
#endif
}
template
<
typename
T
>
...
...
@@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input,
}
}
}
// namespace vllm
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
...
...
@@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a(
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK
(
output_scale
.
size
(
1
)
*
4
==
padded_k
);
auto
in_dtype
=
input
.
dtype
();
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
quant_impl
<
half
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
quant_impl
<
__nv_bfloat16
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Expected input data type to be half or bfloat16"
);
}
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"nvfp4_experts_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
vllm
::
quant_impl
<
cuda_type
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
});
}
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
38d80967
...
...
@@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a(
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void
silu_and_mul_nvfp4_quant_sm1xxa
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input_sf
);
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
...
...
@@ -54,3 +62,13 @@ void scaled_fp4_experts_quant(
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 experts quantization kernel"
);
}
void
silu_and_mul_nvfp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input_sf
)
{
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return
silu_and_mul_nvfp4_quant_sm1xxa
(
output
,
output_sf
,
input
,
input_sf
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled silu_and_mul nvfp4 quantization kernel"
);
}
csrc/quantization/fp4/nvfp4_quant_kernels.cu
View file @
38d80967
...
...
@@ -23,245 +23,18 @@
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
// Get type2 from type or vice versa (applied to half and bfloat16)
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
//
// Each `cvt.rn.satfinite.e2m1x2.f32` instruction consumes two floats and fills
// one 8-bit register; the four 8-bit registers are then combined into the
// 32-bit result with `mov.b32`, byte0 listed first.
// Within every pair the operands are handed to `cvt` in swapped order (the
// odd-indexed element first, the even-indexed element second).
// NOTE(review): presumably this makes the even-indexed element land in the low
// nibble of each byte — confirm against the PTX ISA description of e2m1x2.
//
// When not compiled for __CUDA_ARCH__ >= 1000 the function returns 0 and does
// no conversion.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
#else
  // Fallback for other compilation passes: no conversion is performed.
  return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
//
// Same instruction sequence as the float[8] overload; here each float2
// supplies one pair. For every pair the `.y` component is passed to `cvt`
// first and the `.x` component second, and the four resulting 8-bit registers
// are combined with `mov.b32`, byte0 listed first.
//
// When not compiled for __CUDA_ARCH__ >= 1000 the function returns 0 and does
// no conversion.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  // Fallback for other compilation passes: no conversion is performed.
  return 0;
#endif
}
// Fast reciprocal.
// Returns the result of the PTX `rcp.approx.ftz.f32` instruction applied to
// `a`, i.e. an approximate 1/a (the `.approx` / `.ftz` modifiers select the
// approximate, flush-to-zero variant of the instruction).
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
// Returns the address of the scale-factor (SF) byte that the calling thread
// should write, or nullptr if this thread does not write an SF.
//
// Template parameters:
//   SFType                      pointee type of the SF buffer; the offset is
//                               applied in bytes regardless of this type.
//   CVT_FP4_NUM_THREADS_PER_SF  number of threads sharing one SF (1 or 2).
// Parameters:
//   rowIdx   row (M) index of the element group handled by this thread.
//   colIdx   per-thread column index; divided by CVT_FP4_NUM_THREADS_PER_SF
//            to obtain the SF vector index along K.
//   numCols  number of columns, used to derive the number of K tiles
//            (one K tile covers CVT_FP4_SF_VEC_SIZE * 4 columns).
//   SFout    base pointer of the SF buffer.
// Returns:
//   Byte pointer into SFout for threads with
//   threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0, nullptr for all others.
//   When not compiled for __CUDA_ARCH__ >= 1000 the result is always nullptr.
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads write one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
    // SF vector index (16 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]

    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 16.
    int factor = CVT_FP4_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    // Widen before multiplying: the product is otherwise evaluated in 32-bit
    // arithmetic and can overflow before it is stored in the 64-bit stride.
    int64_t mTileStride = static_cast<int64_t>(numKTiles) * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
                       outerMIdx * outerMStride + innerMIdx * innerMStride +
                       innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  }
#endif
  // Threads that do not own an SF (or unsupported architectures).
  return nullptr;
}
// Define a 16 bytes packed data type.
// The primary template holds four elements of TypeConverter<Type>::Type
// (e.g. 4 x half2 when Type is half).
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// Specialization for __nv_fp8_e4m3: eight packed __nv_fp8x2_e4m3 pairs.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output.
//
// Template parameters:
//   Type      element type of the input; `half` is converted with
//             __half22float2, every other type with __bfloat1622float2.
//   UE8M0_SF  when true the scale factor (SF) is encoded as the 8 exponent
//             bits of its fp32 value; otherwise it is encoded as e4m3.
// Parameters:
//   vec         the CVT_FP4_ELTS_PER_THREAD values owned by this thread.
//   SFScaleVal  scale multiplied into the SF before it is encoded.
//   SFout       where to store the 8-bit SF; nothing is stored if null.
// Returns:
//   Eight e2m1 values packed into one uint32_t. When not compiled for
//   __CUDA_ARCH__ >= 1000 the function returns 0 and does nothing else.
//
// NOTE(review): the shuffle below uses a full-warp mask (uint32_t(-1)), so
// presumably every lane of the warp has to reach this call — confirm against
// the callers.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
                                         uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads): exchange the
  // running maximum with the lane whose id differs in bit 0.
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Encode the SF into 8 bits and convert the encoded value back to fp32, so
  // that the output scale below is derived from the value that is stored.
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(SFValue * reciprocal(SFScaleVal)),
  // where SFValue is by now the fp32 value of the 8-bit encoded SF.
  // A zero SF yields a zero output scale instead of a division by zero.
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(
                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Return the packed e2m1 values (this function does not store them itself).
  return e2m1Vec;
#else
  // Fallback for other compilation passes: no quantization is performed.
  return 0;
#endif
}
namespace
vllm
{
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__global__
void
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
...
...
@@ -293,7 +66,6 @@ cvt_fp16_to_fp4(
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
}
#endif
}
template
<
typename
T
>
...
...
@@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
int
multiProcessorCount
,
cudaStream_t
stream
);
}
// namespace vllm
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
output_sf
,
...
...
@@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
int32_t
n
=
input
.
size
(
1
);
TORCH_CHECK
(
n
%
16
==
0
,
"The N dimension must be multiple of 16."
);
TORCH_CHECK
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
||
input
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
,
"Unsupported input data type for quantize_to_fp4."
);
int
multiProcessorCount
=
get_device_attribute
(
cudaDevAttrMultiProcessorCount
,
-
1
);
...
...
@@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
// We don't support e8m0 scales at this moment.
bool
useUE8M0
=
false
;
switch
(
input
.
scalar_type
())
{
case
torch
::
kHalf
:
{
auto
input_ptr
=
reinterpret_cast
<
half
const
*>
(
input
.
data_ptr
());
invokeFP4Quantization
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
output_ptr
,
sf_out
,
useUE8M0
,
multiProcessorCount
,
stream
);
break
;
}
case
torch
::
kBFloat16
:
{
auto
input_ptr
=
reinterpret_cast
<
__nv_bfloat16
const
*>
(
input
.
data_ptr
());
invokeFP4Quantization
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
output_ptr
,
sf_out
,
useUE8M0
,
multiProcessorCount
,
stream
);
break
;
}
default:
{
std
::
cerr
<<
"Observing: "
<<
input
.
scalar_type
()
<<
" for the input datatype which is invalid"
;
throw
std
::
runtime_error
(
"Unsupported input data type for quantize_to_fp4."
);
}
}
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"nvfp4_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
vllm
::
invokeFP4Quantization
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
output_ptr
,
sf_out
,
useUE8M0
,
multiProcessorCount
,
stream
);
});
}
csrc/quantization/fp4/nvfp4_utils.cuh
0 → 100644
View file @
38d80967
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
namespace
vllm
{
// Convert PyTorch cpp type to CUDA type.
// CUDATypeConverter<T>::Type is T itself unless a specialization below
// replaces it:
//   at::Half     -> half
//   at::BFloat16 -> __nv_bfloat16
template <typename T>
struct CUDATypeConverter {
  using Type = T;
};

template <>
struct CUDATypeConverter<at::Half> {
  using Type = half;
};

template <>
struct CUDATypeConverter<at::BFloat16> {
  using Type = __nv_bfloat16;
};
// Get type2 from type or vice versa (applied to half and bfloat16).
//
// TypeConverter<T>::Type maps a scalar 16-bit float type to its packed
// two-element vector type and back:
//   half          <-> half2
//   __nv_bfloat16 <-> __nv_bfloat162
// NOTE: the primary template below maps every type that has no explicit
// specialization to half2.
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

// half2 -> half
template <>
struct TypeConverter<half2> {
  using Type = half;
};

// half -> half2
template <>
struct TypeConverter<half> {
  using Type = half2;
};

// __nv_bfloat162 -> __nv_bfloat16
template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

// __nv_bfloat16 -> __nv_bfloat162
template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};
// Define a 16 bytes packed data type.
// The primary template holds four elements of TypeConverter<Type>::Type
// (e.g. 4 x half2 when Type is half).
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// Specialization for __nv_fp8_e4m3: eight packed __nv_fp8x2_e4m3 pairs.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
//
// Each `cvt.rn.satfinite.e2m1x2.f32` instruction consumes two floats and fills
// one 8-bit register; the four 8-bit registers are then combined into the
// 32-bit result with `mov.b32`, byte0 listed first.
// Within every pair the operands are handed to `cvt` in swapped order (the
// odd-indexed element first, the even-indexed element second).
// NOTE(review): presumably this makes the even-indexed element land in the low
// nibble of each byte — confirm against the PTX ISA description of e2m1x2.
// NOTE(review): this function carries no __CUDA_ARCH__ guard of its own, so it
// presumably relies on every includer being compiled only for targets that
// support this instruction — confirm.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
//
// Same instruction sequence as the float[8] overload; here each float2
// supplies one pair. For every pair the `.y` component is passed to `cvt`
// first and the `.x` component second, and the four resulting 8-bit registers
// are combined with `mov.b32`, byte0 listed first.
// NOTE(review): this function carries no __CUDA_ARCH__ guard of its own, so it
// presumably relies on every includer being compiled only for targets that
// support this instruction — confirm.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
}
// Fast reciprocal.
// Returns the result of the PTX `rcp.approx.ftz.f32` instruction applied to
// `a`, i.e. an approximate 1/a (the `.approx` / `.ftz` modifiers select the
// approximate, flush-to-zero variant of the instruction).
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
// Returns the address of the scale-factor (SF) byte that the calling thread
// should write, or nullptr if this thread does not write an SF.
//
// Template parameters:
//   SFType                      pointee type of the SF buffer; the offset is
//                               applied in bytes regardless of this type.
//   CVT_FP4_NUM_THREADS_PER_SF  number of threads sharing one SF (1 or 2).
// Parameters:
//   rowIdx   row (M) index of the element group handled by this thread.
//   colIdx   per-thread column index; divided by CVT_FP4_NUM_THREADS_PER_SF
//            to obtain the SF vector index along K.
//   numCols  number of columns, used to derive the number of K tiles
//            (one K tile covers CVT_FP4_SF_VEC_SIZE * 4 columns).
//   SFout    base pointer of the SF buffer.
// Returns:
//   Byte pointer into SFout for threads with
//   threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0, nullptr for all others.
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads write one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
    // SF vector index (16 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]

    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 16.
    int factor = CVT_FP4_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    // Widen before multiplying: the product is otherwise evaluated in 32-bit
    // arithmetic and can overflow before it is stored in the 64-bit stride.
    int64_t mTileStride = static_cast<int64_t>(numKTiles) * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
                       outerMIdx * outerMStride + innerMIdx * innerMStride +
                       innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  }
  // Threads that do not own an SF.
  return nullptr;
}
// Quantizes the provided PackedVec into the uint32_t output.
//
// Template parameters:
//   Type      element type of the input; `half` is converted with
//             __half22float2, every other type with __bfloat1622float2.
//   UE8M0_SF  when true the scale factor (SF) is encoded as the 8 exponent
//             bits of its fp32 value; otherwise it is encoded as e4m3.
// Parameters:
//   vec         the CVT_FP4_ELTS_PER_THREAD values owned by this thread.
//   SFScaleVal  scale multiplied into the SF before it is encoded.
//   SFout       where to store the 8-bit SF; nothing is stored if null.
// Returns:
//   Eight e2m1 values packed into one uint32_t.
//
// NOTE(review): the shuffle below uses a full-warp mask (uint32_t(-1)), so
// presumably every lane of the warp has to reach this call — confirm against
// the callers.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
                                         uint8_t* SFout) {
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads): exchange the
  // running maximum with the lane whose id differs in bit 0.
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Encode the SF into 8 bits and convert the encoded value back to fp32, so
  // that the output scale below is derived from the value that is stored.
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(SFValue * reciprocal(SFScaleVal)),
  // where SFValue is by now the fp32 value of the 8-bit encoded SF.
  // A zero SF yields a zero output scale instead of a division by zero.
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(
                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Return the packed e2m1 values (this function does not store them itself).
  return e2m1Vec;
}
}
// namespace vllm
csrc/quantization/machete/generate.py
View file @
38d80967
...
...
@@ -417,7 +417,7 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
))
def
prepacked_type_key
(
prepack_type
:
PrepackTypeConfig
):
# For now
we
we can just use the first accumulator type seen since
# For now
,
we can just use the first accumulator type seen since
# the tensor core shapes/layouts don't vary based on accumulator
# type so we can generate less code this way
return
(
prepack_type
.
a
,
prepack_type
.
b_num_bits
,
prepack_type
.
convert
)
...
...
csrc/torch_bindings.cpp
View file @
38d80967
...
...
@@ -115,8 +115,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
// ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
#ifndef USE_ROCM
ops
.
def
(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()"
);
...
...
@@ -169,6 +168,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Polynomial Normalization.
ops
.
def
(
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
"epsilon) -> ()"
);
ops
.
impl
(
"poly_norm"
,
torch
::
kCUDA
,
&
poly_norm
);
// Apply repetition penalties to logits in-place
ops
.
def
(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
...
...
@@ -521,10 +526,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// SM100 CUTLASS MLA decode
ops
.
def
(
"sm100_cutlass_mla_decode(Tensor! out, Tensor
q_nop
e, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache,
Tensor seq_lens,
"
" Tensor
page_table, Tensor workspace, float
"
"scale,"
"sm100_cutlass_mla_decode(Tensor! out, Tensor
! ls
e, Tensor q_
no
pe,"
" Tensor
q_pe, Tensor
kv_c_and_k_pe_cache,"
" Tensor
seq_lens, Tensor page_table,
"
"
Tensor workspace, float
scale,"
" int num_kv_splits) -> ()"
);
// conditionally compiled so impl in source file
...
...
@@ -698,16 +703,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()"
);
cache_ops
.
impl
(
"concat_and_cache_mla"
,
torch
::
kCUDA
,
&
concat_and_cache_mla
);
cache_ops
.
def
(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()"
);
cache_ops
.
impl
(
"cp_fused_concat_and_cache_mla"
,
torch
::
kCUDA
,
&
cp_fused_concat_and_cache_mla
);
// Convert the key and value cache to fp8 data type.
cache_ops
.
def
(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
...
...
docker/Dockerfile
View file @
38d80967
...
...
@@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY
.buildkite/check-wheel-size.py check-wheel-size.py
# sync the default value with .buildkite/check-wheel-size.py
ARG
VLLM_MAX_SIZE_MB=4
0
0
ARG
VLLM_MAX_SIZE_MB=4
5
0
ENV
VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG
RUN_WHEEL_CHECK=true
RUN if
[
"
$RUN_WHEEL_CHECK
"
=
"true"
]
;
then
\
...
...
@@ -261,6 +261,8 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV
UV_LINK_MODE=copy
# Install libnuma-dev, required by fastsafetensors (fixes #20384)
RUN
apt-get update
&&
apt-get
install
-y
libnuma-dev
&&
rm
-rf
/var/lib/apt/lists/
*
COPY
requirements/lint.txt requirements/lint.txt
COPY
requirements/test.txt requirements/test.txt
COPY
requirements/dev.txt requirements/dev.txt
...
...
@@ -373,7 +375,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# Install FlashInfer from source
ARG
FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with "flashinfer" extra in setup.py
ARG
FLASHINFER_GIT_REF="v0.
2.14.post1
"
ARG
FLASHINFER_GIT_REF="v0.
3.0
"
# Flag to control whether to compile FlashInfer AOT kernels
# Set to "true" to enable AOT compilation:
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
...
...
@@ -432,11 +434,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url
${
PYTORCH_CUDA_INDEX_BASE_URL
}
/cu
$(
echo
$CUDA_VERSION
|
cut
-d
.
-f1
,2 |
tr
-d
'.'
)
# Install DeepGEMM from source
ARG
DEEPGEMM_GIT_REF
="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
ARG
DEEPGEMM_GIT_REF
COPY
tools/install_deepgemm.sh /tmp/install_deepgemm.sh
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
VLLM_DOCKER_BUILD_CONTEXT
=
1 /tmp/install_deepgemm.sh
--cuda-version
"
${
CUDA_VERSION
}
"
--ref
"
${
DEEPGEMM_GIT_REF
}
"
\
&&
rm
/tmp/install_deepgemm.sh
VLLM_DOCKER_BUILD_CONTEXT
=
1 /tmp/install_deepgemm.sh
--cuda-version
"
${
CUDA_VERSION
}
"
${
DEEPGEMM_GIT_REF
:+--ref
"
$DEEPGEMM_GIT_REF
"
}
# Install EP kernels(pplx-kernels and DeepEP), NixL
COPY
tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
...
...
@@ -518,7 +519,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
else
\
BITSANDBYTES_VERSION
=
"0.46.1"
;
\
fi
;
\
uv pip
install
--system
accelerate hf_transfer modelscope
"bitsandbytes>=
${
BITSANDBYTES_VERSION
}
"
'timm
==0.9
.1
0
'
boto3 runai-model-streamer runai-model-streamer[s3]
uv pip
install
--system
accelerate hf_transfer modelscope
"bitsandbytes>=
${
BITSANDBYTES_VERSION
}
"
'timm
>=1.0
.1
7
'
boto3 runai-model-streamer runai-model-streamer[s3]
ENV
VLLM_USAGE_SOURCE production-docker-image
...
...
docker/Dockerfile.neuron
deleted
100644 → 0
View file @
33650733
# Development image for running vLLM on AWS Neuron devices.

# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04"

FROM $BASE_IMAGE

# An ARG declared before FROM is outside of the build stage and is not visible
# to the instructions that follow FROM. Redeclare it (without a value) so the
# default above is in scope; otherwise the echo below prints an empty name.
ARG BASE_IMAGE

RUN echo "Base image is $BASE_IMAGE"

# Install some basic utilities
RUN apt-get update && \
    apt-get install -y \
    git \
    python3 \
    python3-pip \
    ffmpeg libsm6 libxext6 libgl1

### Mount Point ###
# When launching the container, mount the code directory to /workspace
ARG APP_MOUNT=/workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}/vllm

# Python tooling and the Neuron compiler
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install pytest

# uninstall transformers-neuronx package explicitly to avoid version conflict
RUN python3 -m pip uninstall -y transformers-neuronx

COPY . .
# Optional repository sanity check; enabled by building with a non-zero
# GIT_REPO_CHECK build argument.
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi

# Build dependencies plus the Neuron-specific requirements
RUN python3 -m pip install -U \
    'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
    -r requirements/neuron.txt

# Editable install of the project itself, built for the neuron target
ENV VLLM_TARGET_DEVICE neuron
RUN --mount=type=bind,source=.git,target=.git \
    pip install --no-build-isolation -v -e .

# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils

# install transformers-neuronx package as an optional dependency (for V0)
# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict
RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps

RUN python3 -m pip install sentencepiece transformers==4.48.0 -U

# overwrite entrypoint to run bash script
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py

CMD ["/bin/bash"]
Prev
1
2
3
4
5
6
7
8
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment