Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1c8f379
Unverified
Commit
a1c8f379
authored
Mar 11, 2025
by
Jeff Daily
Committed by
GitHub
Mar 11, 2025
Browse files
dynamic distpatch of fp8 kernels (#14245)
Signed-off-by:
Jeff Daily
<
jeff.daily@amd.com
>
parent
08a1a112
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
236 additions
and
146 deletions
+236
-146
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+1
-2
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+26
-6
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+33
-25
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+22
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+43
-25
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+59
-27
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+3
-0
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+11
-8
csrc/quantization/vectorization.cuh
csrc/quantization/vectorization.cuh
+0
-1
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+1
-2
tests/kernels/test_triton_scaled_mm.py
tests/kernels/test_triton_scaled_mm.py
+2
-5
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+3
-6
vllm/_custom_ops.py
vllm/_custom_ops.py
+17
-18
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+3
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+2
-2
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+5
-9
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+1
-2
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+2
-2
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
a1c8f379
...
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
...
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
(
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
)
else
torch
.
float8_e4m3fn
class
BenchmarkConfig
(
TypedDict
):
class
BenchmarkConfig
(
TypedDict
):
...
...
csrc/dispatch_utils.h
View file @
a1c8f379
...
@@ -6,6 +6,11 @@
...
@@ -6,6 +6,11 @@
#include <torch/all.h>
#include <torch/all.h>
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
...
@@ -14,17 +19,32 @@
...
@@ -14,17 +19,32 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
#ifndef USE_ROCM
// A host-based check at runtime will create a preferred FP8 type for ROCm
// such that the correct kernel is dispatched.
#ifdef USE_ROCM
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#else
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn
uz
, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
#endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
...
...
csrc/layernorm_quant_kernels.cu
View file @
a1c8f379
...
@@ -21,9 +21,9 @@
...
@@ -21,9 +21,9 @@
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
__global__
void
rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
const
float
*
__restrict__
scale
,
// [1]
...
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
...
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
Additional optimizations we can make in this case are
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
packed and vectorized operations, which help with the
memory latency bottleneck. */
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
out
[
id
*
width
+
i
]
=
out
[
id
*
width
+
i
]
=
scaled_fp8_conversion
<
true
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
}
}
}
...
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
/* Generic fused_add_rms_norm_kernel
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
The width field is not used here but necessary for other specializations.
*/
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
...
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
>
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_tokens
,
hidden_size
);
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
});
});
void
fused_add_rms_norm_static_fp8_quant
(
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
a1c8f379
...
@@ -13,6 +13,28 @@ namespace vllm {
...
@@ -13,6 +13,28 @@ namespace vllm {
namespace
fp8
{
namespace
fp8
{
#ifdef ENABLE_FP8
#ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template
<
typename
fp8_type
>
__device__
__forceinline__
fp8_type
cvt_c10
(
float
const
r
)
{
return
{};
}
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
cvt_c10
(
float
const
r
)
{
return
c10
::
Float8_e4m3fn
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3
::
__default_saturation
,
__hip_fp8_e4m3
::
__default_interpret
),
c10
::
Float8_e4m3fn
::
from_bits
());
}
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fnuz
cvt_c10
(
float
const
r
)
{
return
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3_fnuz
::
__default_saturation
,
__hip_fp8_e4m3_fnuz
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
}
template
<
typename
Tout
,
typename
Tin
>
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
return
x
;
return
x
;
...
...
csrc/quantization/fp8/common.cu
View file @
a1c8f379
...
@@ -11,8 +11,8 @@
...
@@ -11,8 +11,8 @@
namespace
vllm
{
namespace
vllm
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
__global__
void
scaled_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
const
float
*
__restrict__
scale
,
const
float
*
__restrict__
scale
,
int64_t
num_elems
)
{
int64_t
num_elems
)
{
...
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
...
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
out
,
input
,
inverted_scale
,
num_elems
,
tid
,
blockDim
.
x
*
gridDim
.
x
);
out
,
input
,
inverted_scale
,
num_elems
,
tid
,
blockDim
.
x
*
gridDim
.
x
);
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
float
*
__restrict__
scale
,
fp8_type
*
__restrict__
out
,
float
*
__restrict__
scale
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
FP8_E4M3_MAX
*
512.
f
);
float
const
min_scaling_factor
=
1.0
f
/
(
fp8_e4m3_adjusted_max_v
<
fp8_type
>
*
512.
f
);
int
const
tid
=
threadIdx
.
x
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
...
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
// Use int64 to avoid overflowing an int32 when calculating this offset
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
offset
];
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
// aligned at 8-byte and 4-byte addresses respectively.
...
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale
=
block_absmax_val_maybe
;
token_scale
=
block_absmax_val_maybe
;
}
}
// token scale computation
// token scale computation
token_scale
=
max
(
token_scale
/
FP8_E4M3_MAX
,
min_scaling_factor
);
token_scale
=
max
(
token_scale
/
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
min_scaling_factor
);
scale
[
token_idx
]
=
token_scale
;
scale
[
token_idx
]
=
token_scale
;
}
}
__syncthreads
();
__syncthreads
();
...
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_output
,
token_input
,
token_scale
,
hidden_size
,
tid
,
blockDim
.
x
);
token_output
,
token_input
,
token_scale
,
hidden_size
,
tid
,
blockDim
.
x
);
}
else
{
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
token_output
[
i
]
=
scaled_fp8_conversion
<
false
>
(
token_output
[
i
]
=
scaled_fp8_conversion
<
false
,
fp8_type
>
(
static_cast
<
float
>
(
token_input
[
i
]),
token_scale
);
static_cast
<
float
>
(
token_input
[
i
]),
token_scale
);
}
}
}
}
...
@@ -96,11 +98,15 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -96,11 +98,15 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
input
.
scalar_type
(),
"scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
VLLM_DISPATCH_FP8_TYPES
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
scalar_type
(),
"scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
});
});
}
}
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
...
@@ -114,13 +120,19 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -114,13 +120,19 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
input
.
scalar_type
(),
"scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
vllm
::
segmented_max_reduction
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
VLLM_DISPATCH_FP8_TYPES
(
scale
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
num_elems
);
out
.
scalar_type
(),
"scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
segmented_max_reduction
<
scalar_t
,
fp8_t
>
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
scale
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
num_elems
);
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
});
});
}
}
void
dynamic_per_token_scaled_fp8_quant
(
void
dynamic_per_token_scaled_fp8_quant
(
...
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
...
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel"
,
[
&
]
{
input
.
scalar_type
(),
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
>
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
scales
.
data_ptr
<
float
>
(),
out
.
data_ptr
<
fp8_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
hidden_size
);
hidden_size
);
});
});
});
}
}
csrc/quantization/fp8/common.cuh
View file @
a1c8f379
...
@@ -7,18 +7,52 @@
...
@@ -7,18 +7,52 @@
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fn.h>
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#else
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
#include "amd/quant_utils.cuh"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
// Using the default max value from pytorch (240.0) will cause accuracy
#define MAYBE_HOST_DEVICE
// issue when running dynamic quantization. Here use 224.0f for rocm.
#endif
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
// Determines the preferred FP8 type for the current platform.
// Note that for CUDA this just returns true,
// but on ROCm it will check device props.
static
bool
is_fp8_ocp
()
{
#ifndef USE_ROCM
return
true
;
#else
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
std
::
string
device_arch
=
dprops
->
gcnArchName
;
size_t
substring
=
device_arch
.
find
(
"gfx94"
);
return
substring
==
std
::
string
::
npos
;
#endif
#endif
constexpr
static
auto
kFp8Type
=
c10
::
CppTypeToScalarType
<
FP8_TYPE
>::
value
;
}
template
<
typename
T
>
struct
fp8_e4m3_adjusted_max
;
template
<
>
struct
fp8_e4m3_adjusted_max
<
c10
::
Float8_e4m3fn
>
{
static
constexpr
c10
::
Float8_e4m3fn
val
()
{
return
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
();
}
};
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
template
<
>
struct
fp8_e4m3_adjusted_max
<
c10
::
Float8_e4m3fnuz
>
{
static
constexpr
c10
::
Float8_e4m3fnuz
val
()
{
return
c10
::
Float8_e4m3fnuz
(
0x7E
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
}
};
template
<
typename
T
>
MAYBE_HOST_DEVICE
static
constexpr
T
fp8_e4m3_adjusted_max_v
=
fp8_e4m3_adjusted_max
<
T
>::
val
();
namespace
vllm
{
namespace
vllm
{
...
@@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
...
@@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
return
old
;
return
old
;
}
}
template
<
bool
is_scale_inverted
>
template
<
bool
is_scale_inverted
,
typename
fp8_type
>
__device__
__forceinline__
FP8_TYPE
scaled_fp8_conversion
(
float
const
val
,
__device__
__forceinline__
fp8_type
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
float
const
scale
)
{
float
x
=
0.0
f
;
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
if
constexpr
(
is_scale_inverted
)
{
...
@@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
...
@@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
x
=
val
/
scale
;
x
=
val
/
scale
;
}
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
float
r
=
fmax
(
-
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
fmin
(
x
,
fp8_e4m3_adjusted_max_v
<
fp8_type
>
));
#ifndef USE_ROCM
#ifndef USE_ROCM
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
return
static_cast
<
fp8_type
>
(
r
);
#else
#else
// Use hardware cvt instruction for fp8 on rocm
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
return
fp8
::
cvt_c10
<
fp8_type
>
(
r
);
__hip_cvt_float_to_fp8
(
r
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
#endif
}
}
...
@@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
...
@@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
// So to get the right answer, *scale needs to be initialized to
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
// finish before consuming *scale.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
int64_t
num_elems
)
{
...
@@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
...
@@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
FP8_E4M3_MAX
);
atomicMaxFloat
(
scale
,
cache
[
0
]
/
fp8_e4m3_adjusted_max_v
<
fp8_type
>
);
}
}
}
}
...
@@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
...
@@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
return
absmax_val
;
return
absmax_val
;
}
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
template
<
typename
scalar_t
,
bool
is_scale_inverted
,
typename
fp8_type
>
__device__
void
scaled_fp8_conversion_vec
(
FP8_TYPE
*
__restrict__
out
,
__device__
void
scaled_fp8_conversion_vec
(
fp8_type
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
float
const
scale
,
int64_t
const
num_elems
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
int
const
tid
,
int
const
step
)
{
using
float8x4_t
=
q8x4_t
<
FP8_TYPE
>
;
using
float8x4_t
=
q8x4_t
<
fp8_type
>
;
// Vectorized input/output to better utilize memory bandwidth.
// Vectorized input/output to better utilize memory bandwidth.
auto
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
auto
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
auto
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
auto
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
...
@@ -141,20 +173,20 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
...
@@ -141,20 +173,20 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
float8x4_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
vectorized_out
[
i
]
=
out_vec
;
vectorized_out
[
i
]
=
out_vec
;
}
}
// Handle the remaining elements if num_elems is not divisible by 4
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
}
}
...
...
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
View file @
a1c8f379
...
@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
...
@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
torch
::
Tensor
&
scales
,
// [num_tokens]
torch
::
Tensor
&
scales
,
// [num_tokens]
double
const
var_epsilon
,
// Variance epsilon used in norm calculation
double
const
var_epsilon
,
// Variance epsilon used in norm calculation
std
::
optional
<
at
::
Tensor
>
scale_ub
,
std
::
optional
<
at
::
Tensor
>
residual
)
{
std
::
optional
<
at
::
Tensor
>
scale_ub
,
std
::
optional
<
at
::
Tensor
>
residual
)
{
static
c10
::
ScalarType
kFp8Type
=
is_fp8_ocp
()
?
c10
::
ScalarType
::
Float8_e4m3fn
:
c10
::
ScalarType
::
Float8_e4m3fnuz
;
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
out
.
is_contiguous
()
&&
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
()
&&
input
.
is_contiguous
());
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
a1c8f379
...
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
...
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#endif
#endif
}
}
static
__device__
__forceinline__
FP8_TYPE
float_to_fp8
(
float
const
x
)
{
template
<
typename
fp8_type
>
float
const
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
static
__device__
__forceinline__
fp8_type
float_to_fp8
(
float
const
x
)
{
return
static_cast
<
FP8_TYPE
>
(
r
);
float
const
r
=
fmax
(
-
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
fmin
(
x
,
fp8_e4m3_adjusted_max_v
<
fp8_type
>
));
return
static_cast
<
fp8_type
>
(
r
);
}
}
template
<
typename
quant_type_t
,
bool
is_scale_inverted
,
typename
enable
=
void
>
template
<
typename
quant_type_t
,
bool
is_scale_inverted
,
typename
enable
=
void
>
...
@@ -54,15 +56,16 @@ struct ScaledQuant<
...
@@ -54,15 +56,16 @@ struct ScaledQuant<
};
};
template
<
typename
quant_type_t
,
bool
is_scale_inverted
>
template
<
typename
quant_type_t
,
bool
is_scale_inverted
>
struct
ScaledQuant
<
struct
ScaledQuant
<
quant_type_t
,
is_scale_inverted
,
quant_type_t
,
is_scale_inverted
,
typename
std
::
enable_if_t
<
typename
std
::
enable_if_t
<
std
::
is_same_v
<
quant_type_t
,
FP8_TYPE
>>>
{
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fnuz
>>>
{
static
__device__
__forceinline__
quant_type_t
quant_fn
(
float
const
x
,
static
__device__
__forceinline__
quant_type_t
quant_fn
(
float
const
x
,
float
const
scale
)
{
float
const
scale
)
{
if
constexpr
(
is_scale_inverted
)
{
if
constexpr
(
is_scale_inverted
)
{
return
float_to_fp8
(
x
*
scale
);
return
float_to_fp8
<
quant_type_t
>
(
x
*
scale
);
}
else
{
}
else
{
return
float_to_fp8
(
x
/
scale
);
return
float_to_fp8
<
quant_type_t
>
(
x
/
scale
);
}
}
}
}
};
};
...
...
csrc/quantization/vectorization.cuh
View file @
a1c8f379
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
*/
*/
// Include both AMD and NVIDIA fp8 types to avoid circular import
// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fn.h>
...
...
tests/kernels/quant_utils.py
View file @
a1c8f379
...
@@ -9,8 +9,7 @@ from vllm.platforms import current_platform
...
@@ -9,8 +9,7 @@ from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX
=
224.0
ROCM_FP8_MAX
=
224.0
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
()
\
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
else
torch
.
float8_e4m3fn
def
as_float32_tensor
(
x
:
Union
[
float
,
torch
.
tensor
])
->
torch
.
tensor
:
def
as_float32_tensor
(
x
:
Union
[
float
,
torch
.
tensor
])
->
torch
.
tensor
:
...
...
tests/kernels/test_triton_scaled_mm.py
View file @
a1c8f379
...
@@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor,
...
@@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor,
def
get_8bit_types
():
def
get_8bit_types
():
types
=
[
torch
.
int8
]
types
=
[
torch
.
int8
]
supports_fp8
=
current_platform
.
has_device_capability
(
89
)
if
current_platform
.
supports_fp8
():
if
current_platform
.
is_rocm
()
and
supports_fp8
:
types
.
append
(
current_platform
.
fp8_dtype
())
types
.
append
(
torch
.
float8_e4m3fnuz
)
elif
current_platform
.
is_cuda
()
and
supports_fp8
:
types
.
append
(
torch
.
float8_e4m3fn
)
return
types
return
types
...
...
tests/quantization/test_fp8.py
View file @
a1c8f379
...
@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
...
@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert
attn
.
_v_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
if
current_platform
.
has_device_capability
(
if
current_platform
.
supports_fp8
()
and
not
force_marlin
:
89
)
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
else
:
...
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
...
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
# for weight-only quantization using Marlin kernels
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
assert
fc1
.
weight
.
dtype
==
torch
.
int32
elif
current_platform
.
is_rocm
():
elif
current_platform
.
is_rocm
():
# Only MI300 and above support quantization='fp8'
if
current_platform
.
supports_fp8
()
and
not
force_marlin
:
if
current_platform
.
has_device_capability
(
94
)
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fnuz
assert
fc1
.
weight
.
dtype
==
current_platform
.
fp8_dtype
()
else
:
# unsupported ROCm platform
else
:
# unsupported ROCm platform
pytest
.
skip
(
pytest
.
skip
(
"Skip `test_load_fp16_model`. "
"Skip `test_load_fp16_model`. "
...
...
vllm/_custom_ops.py
View file @
a1c8f379
...
@@ -875,9 +875,8 @@ def scaled_fp8_quant(
...
@@ -875,9 +875,8 @@ def scaled_fp8_quant(
# This code assumes batch_dim and num_tokens are flattened
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
torch
.
float8_e4m3fnuz
\
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
if
num_token_padding
:
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
...
...
vllm/attention/backends/mla/common.py
View file @
a1c8f379
...
@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8
)
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
current_platform_fp8_dtype
,
is_fp8
)
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
from
vllm.model_executor.layers.rotary_embedding
import
(
...
@@ -1238,7 +1238,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1238,7 +1238,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
_
fp8_dtype
)
quant_dtype
=
current_platform
.
fp8_dtype
()
)
# For FP8 save the transpose so we can use
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
...
@@ -1255,7 +1255,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1255,7 +1255,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
_
fp8_dtype
)
quant_dtype
=
current_platform
.
fp8_dtype
()
)
# For FP8 save the transpose so we can use
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
a1c8f379
...
@@ -158,8 +158,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -158,8 +158,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
current_platform
.
is_fp8_fnuz
():
if
current_platform
.
is_rocm
():
# Normalize the weights and scales
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
a1c8f379
...
@@ -42,7 +42,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -42,7 +42,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
...
@@ -60,7 +60,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -60,7 +60,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
weight
=
layer
.
weight
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
a1c8f379
...
@@ -127,7 +127,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -127,7 +127,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
weight
=
layer
.
weight
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
a1c8f379
...
@@ -270,7 +270,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -270,7 +270,7 @@ class Fp8LinearMethod(LinearMethodBase):
# TODO(rob): refactor block quant into separate class.
# TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
weight_scale_inv
,
_
=
\
weight
,
weight_scale_inv
,
_
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
...
@@ -327,8 +327,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -327,8 +327,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
weight_scale
=
layer
.
weight_scale
# If rocm, use float8_e4m3fnuz.
if
current_platform
.
is_fp8_fnuz
():
if
current_platform
.
is_rocm
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
...
@@ -533,7 +532,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -533,7 +532,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO (rob): refactor block quant into separate class.
# TODO (rob): refactor block quant into separate class.
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
w13_weight
,
w13_weight_scale_inv
,
w13_input_scale
=
\
w13_weight
,
w13_weight_scale_inv
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale_inv
,
layer
.
w13_weight
,
layer
.
w13_weight_scale_inv
,
...
@@ -559,9 +558,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -559,9 +558,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
torch
.
float8_e4m3fnuz
\
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
...
@@ -608,8 +605,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -608,8 +605,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
current_platform
.
is_fp8_fnuz
():
if
current_platform
.
is_rocm
():
# Normalize the weights and scales
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
a1c8f379
...
@@ -142,8 +142,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -142,8 +142,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
current_platform
.
is_fp8_fnuz
():
if
current_platform
.
is_rocm
():
# Normalize the weights and scales
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
a1c8f379
...
@@ -39,7 +39,7 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -39,7 +39,7 @@ class QuarkW8A8Fp8(QuarkScheme):
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
weight_scale
=
max_w_scale
,
weight_scale
=
max_w_scale
,
...
@@ -55,7 +55,7 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -55,7 +55,7 @@ class QuarkW8A8Fp8(QuarkScheme):
elif
self
.
qscheme
==
"per_channel"
:
elif
self
.
qscheme
==
"per_channel"
:
weight
=
layer
.
weight
weight
=
layer
.
weight
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment