Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1c8f379
Unverified
Commit
a1c8f379
authored
Mar 11, 2025
by
Jeff Daily
Committed by
GitHub
Mar 11, 2025
Browse files
dynamic dispatch of fp8 kernels (#14245)
Signed-off-by:
Jeff Daily
<
jeff.daily@amd.com
>
parent
08a1a112
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
236 additions
and
146 deletions
+236
-146
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+1
-2
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+26
-6
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+33
-25
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+22
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+43
-25
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+59
-27
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+3
-0
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+11
-8
csrc/quantization/vectorization.cuh
csrc/quantization/vectorization.cuh
+0
-1
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+1
-2
tests/kernels/test_triton_scaled_mm.py
tests/kernels/test_triton_scaled_mm.py
+2
-5
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+3
-6
vllm/_custom_ops.py
vllm/_custom_ops.py
+17
-18
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+3
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+2
-2
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+5
-9
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+1
-2
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+2
-2
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
a1c8f379
...
...
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
(
)
else
torch
.
float8_e4m3fn
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
class
BenchmarkConfig
(
TypedDict
):
...
...
csrc/dispatch_utils.h
View file @
a1c8f379
...
...
@@ -6,6 +6,11 @@
#include <torch/all.h>
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
...
...
@@ -14,17 +19,32 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
// A host-based check at runtime will create a preferred FP8 type for ROCm
// such that the correct kernel is dispatched.
#ifdef USE_ROCM
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
...
...
csrc/layernorm_quant_kernels.cu
View file @
a1c8f379
...
...
@@ -21,9 +21,9 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
...
...
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
...
...
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
...
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
out
[
id
*
width
+
i
]
=
scaled_fp8_conversion
<
true
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
...
...
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
...
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
...
...
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
});
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
input
,
// [..., hidden_size]
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
a1c8f379
...
...
@@ -13,6 +13,28 @@ namespace vllm {
namespace
fp8
{
#ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template
<
typename
fp8_type
>
__device__
__forceinline__
fp8_type
cvt_c10
(
float
const
r
)
{
return
{};
}
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
cvt_c10
(
float
const
r
)
{
return
c10
::
Float8_e4m3fn
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3
::
__default_saturation
,
__hip_fp8_e4m3
::
__default_interpret
),
c10
::
Float8_e4m3fn
::
from_bits
());
}
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fnuz
cvt_c10
(
float
const
r
)
{
return
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3_fnuz
::
__default_saturation
,
__hip_fp8_e4m3_fnuz
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
}
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
return
x
;
...
...
csrc/quantization/fp8/common.cu
View file @
a1c8f379
...
...
@@ -11,8 +11,8 @@
namespace
vllm
{
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
scaled_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
float
*
__restrict__
scale
,
int64_t
num_elems
)
{
...
...
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
out
,
input
,
inverted_scale
,
num_elems
,
tid
,
blockDim
.
x
*
gridDim
.
x
);
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
float
*
__restrict__
scale
,
fp8_type
*
__restrict__
out
,
float
*
__restrict__
scale
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
FP8_E4M3_MAX
*
512.
f
);
float
const
min_scaling_factor
=
1.0
f
/
(
fp8_e4m3_adjusted_max_v
<
fp8_type
>
*
512.
f
);
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
...
...
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
offset
];
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
...
...
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale
=
block_absmax_val_maybe
;
}
// token scale computation
token_scale
=
max
(
token_scale
/
FP8_E4M3_MAX
,
min_scaling_factor
);
token_scale
=
max
(
token_scale
/
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
min_scaling_factor
);
scale
[
token_idx
]
=
token_scale
;
}
__syncthreads
();
...
...
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_output
,
token_input
,
token_scale
,
hidden_size
,
tid
,
blockDim
.
x
);
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
token_output
[
i
]
=
scaled_fp8_conversion
<
false
>
(
token_output
[
i
]
=
scaled_fp8_conversion
<
false
,
fp8_type
>
(
static_cast
<
float
>
(
token_input
[
i
]),
token_scale
);
}
}
...
...
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
input
.
scalar_type
(),
"scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
});
}
...
...
@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
segmented_max_reduction
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
scale
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
num_elems
);
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
input
.
scalar_type
(),
"scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
segmented_max_reduction
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
scale
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
num_elems
);
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
});
}
...
...
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
hidden_size
);
input
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
hidden_size
);
});
});
}
csrc/quantization/fp8/common.cuh
View file @
a1c8f379
...
...
@@ -7,18 +7,52 @@
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
#define MAYBE_HOST_DEVICE
#endif
// Determines the preferred FP8 type for the current platform.
// Note that for CUDA this just returns true,
// but on ROCm it will check device props.
static
bool
is_fp8_ocp
()
{
#ifndef USE_ROCM
return
true
;
#else
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
std
::
string
device_arch
=
dprops
->
gcnArchName
;
size_t
substring
=
device_arch
.
find
(
"gfx94"
);
return
substring
==
std
::
string
::
npos
;
#endif
constexpr
static
auto
kFp8Type
=
c10
::
CppTypeToScalarType
<
FP8_TYPE
>::
value
;
}
template
<
typename
T
>
struct
fp8_e4m3_adjusted_max
;
template
<
>
struct
fp8_e4m3_adjusted_max
<
c10
::
Float8_e4m3fn
>
{
static
constexpr
c10
::
Float8_e4m3fn
val
()
{
return
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
();
}
};
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
template
<
>
struct
fp8_e4m3_adjusted_max
<
c10
::
Float8_e4m3fnuz
>
{
static
constexpr
c10
::
Float8_e4m3fnuz
val
()
{
return
c10
::
Float8_e4m3fnuz
(
0x7E
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
}
};
template
<
typename
T
>
MAYBE_HOST_DEVICE
static
constexpr
T
fp8_e4m3_adjusted_max_v
=
fp8_e4m3_adjusted_max
<
T
>::
val
();
namespace
vllm
{
...
...
@@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
return
old
;
}
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
FP8_TYPE
scaled_fp8_conversion
(
float
const
val
,
template
<
bool
is_scale_inverted
,
typename
fp8_type
>
__device__
__forceinline__
fp8_type
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
...
...
@@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
x
=
val
/
scale
;
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
float
r
=
fmax
(
-
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
fmin
(
x
,
fp8_e4m3_adjusted_max_v
<
fp8_type
>
));
#ifndef USE_ROCM
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
return
static_cast
<
fp8_type
>
(
r
);
#else
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
r
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
return
fp8
::
cvt_c10
<
fp8_type
>
(
r
);
#endif
}
...
...
@@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
...
...
@@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
FP8_E4M3_MAX
);
atomicMaxFloat
(
scale
,
cache
[
0
]
/
fp8_e4m3_adjusted_max_v
<
fp8_type
>
);
}
}
...
...
@@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
return
absmax_val
;
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
__device__
void
scaled_fp8_conversion_vec
(
FP8_TYPE
*
__restrict__
out
,
template
<
typename
scalar_t
,
bool
is_scale_inverted
,
typename
fp8_type
>
__device__
void
scaled_fp8_conversion_vec
(
fp8_type
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
using
float8x4_t
=
q8x4_t
<
FP8_TYPE
>
;
using
float8x4_t
=
q8x4_t
<
fp8_type
>
;
// Vectorized input/output to better utilize memory bandwidth.
auto
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
auto
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
...
...
@@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
vectorized_out
[
i
]
=
out_vec
;
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
}
// namespace vllm
\ No newline at end of file
}
// namespace vllm
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
View file @
a1c8f379
...
...
@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
torch
::
Tensor
&
scales
,
// [num_tokens]
double
const
var_epsilon
,
// Variance epsilon used in norm calculation
std
::
optional
<
at
::
Tensor
>
scale_ub
,
std
::
optional
<
at
::
Tensor
>
residual
)
{
static
c10
::
ScalarType
kFp8Type
=
is_fp8_ocp
()
?
c10
::
ScalarType
::
Float8_e4m3fn
:
c10
::
ScalarType
::
Float8_e4m3fnuz
;
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
out
.
is_contiguous
()
&&
input
.
is_contiguous
());
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
a1c8f379
...
...
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#endif
}
static
__device__
__forceinline__
FP8_TYPE
float_to_fp8
(
float
const
x
)
{
float
const
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
return
static_cast
<
FP8_TYPE
>
(
r
);
template
<
typename
fp8_type
>
static
__device__
__forceinline__
fp8_type
float_to_fp8
(
float
const
x
)
{
float
const
r
=
fmax
(
-
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
fmin
(
x
,
fp8_e4m3_adjusted_max_v
<
fp8_type
>
));
return
static_cast
<
fp8_type
>
(
r
);
}
template
<
typename
quant_type_t
,
bool
is_scale_inverted
,
typename
enable
=
void
>
...
...
@@ -54,15 +56,16 @@ struct ScaledQuant<
};
template
<
typename
quant_type_t
,
bool
is_scale_inverted
>
struct
ScaledQuant
<
quant_type_t
,
is_scale_inverted
,
typename
std
::
enable_if_t
<
std
::
is_same_v
<
quant_type_t
,
FP8_TYPE
>>>
{
struct
ScaledQuant
<
quant_type_t
,
is_scale_inverted
,
typename
std
::
enable_if_t
<
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fnuz
>>>
{
static
__device__
__forceinline__
quant_type_t
quant_fn
(
float
const
x
,
float
const
scale
)
{
if
constexpr
(
is_scale_inverted
)
{
return
float_to_fp8
(
x
*
scale
);
return
float_to_fp8
<
quant_type_t
>
(
x
*
scale
);
}
else
{
return
float_to_fp8
(
x
/
scale
);
return
float_to_fp8
<
quant_type_t
>
(
x
/
scale
);
}
}
};
...
...
csrc/quantization/vectorization.cuh
View file @
a1c8f379
...
...
@@ -4,7 +4,6 @@
*/
// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>
...
...
tests/kernels/quant_utils.py
View file @
a1c8f379
...
...
@@ -9,8 +9,7 @@ from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX
=
224.0
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
()
\
else
torch
.
float8_e4m3fn
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
def
as_float32_tensor
(
x
:
Union
[
float
,
torch
.
tensor
])
->
torch
.
tensor
:
...
...
tests/kernels/test_triton_scaled_mm.py
View file @
a1c8f379
...
...
@@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor,
def
get_8bit_types
():
types
=
[
torch
.
int8
]
supports_fp8
=
current_platform
.
has_device_capability
(
89
)
if
current_platform
.
is_rocm
()
and
supports_fp8
:
types
.
append
(
torch
.
float8_e4m3fnuz
)
elif
current_platform
.
is_cuda
()
and
supports_fp8
:
types
.
append
(
torch
.
float8_e4m3fn
)
if
current_platform
.
supports_fp8
():
types
.
append
(
current_platform
.
fp8_dtype
())
return
types
...
...
tests/quantization/test_fp8.py
View file @
a1c8f379
...
...
@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert
attn
.
_v_scale
==
1.0
if
current_platform
.
is_cuda
():
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
if
current_platform
.
supports_fp8
()
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
...
...
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
elif
current_platform
.
is_rocm
():
# Only MI300 and above support quantization='fp8'
if
current_platform
.
has_device_capability
(
94
)
and
not
force_marlin
:
if
current_platform
.
supports_fp8
()
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fnuz
assert
fc1
.
weight
.
dtype
==
current_platform
.
fp8_dtype
()
else
:
# unsupported ROCm platform
pytest
.
skip
(
"Skip `test_load_fp16_model`. "
...
...
vllm/_custom_ops.py
View file @
a1c8f379
...
...
@@ -478,16 +478,16 @@ def cutlass_scaled_mm(a: torch.Tensor,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
`cutlass_scaled_mm` implements a fused version of
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension target_shape[dim] // src_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
...
...
@@ -564,7 +564,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
with Cutlass sparse kernels.
Args:
a (torch.Tensor):
a (torch.Tensor):
The input tensor to be compressed. Must have one of the following data types:
- `torch.int8`
- `torch.float8_e4m3fn`
...
...
@@ -572,7 +572,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
- `torch.float16`
Returns:
tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
...
...
@@ -875,9 +875,8 @@ def scaled_fp8_quant(
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype
:
torch
.
dtype
=
torch
.
float8_e4m3fnuz
\
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
# For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
...
...
@@ -908,7 +907,7 @@ def allspark_repack_weight(
has_zp
:
bool
=
False
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
for Ampere W8A16 Fused Gemm kernel
Args:
...
...
@@ -917,10 +916,10 @@ def allspark_repack_weight(
zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
Must be provided for asymmetric quantization.
has_zp: if use symmetric quantization, has_zp = False.
if use asymmetric quantization, has_zp = True.
if use asymmetric quantization, has_zp = True.
Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
rearranged weight, scale, and optionally zero_point.
"""
K
=
qweight
.
shape
[
0
]
...
...
vllm/attention/backends/mla/common.py
View file @
a1c8f379
...
...
@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
current_platform_fp8_dtype
,
is_fp8
)
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
...
...
@@ -1238,7 +1238,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
_
fp8_dtype
)
quant_dtype
=
current_platform
.
fp8_dtype
()
)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
...
...
@@ -1255,7 +1255,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
_
fp8_dtype
)
quant_dtype
=
current_platform
.
fp8_dtype
()
)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
a1c8f379
...
...
@@ -158,8 +158,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
current_platform
.
is_rocm
():
if
current_platform
.
is_fp8_fnuz
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
a1c8f379
...
...
@@ -42,7 +42,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths
=
layer
.
logical_widths
,
)
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
...
...
@@ -60,7 +60,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
weight_scale
,
input_scale
=
\
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
a1c8f379
...
...
@@ -127,7 +127,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
a1c8f379
...
...
@@ -270,7 +270,7 @@ class Fp8LinearMethod(LinearMethodBase):
# TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
weight_scale_inv
,
_
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
layer
.
weight
,
...
...
@@ -327,8 +327,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If rocm, use float8_e4m3fnuz.
if
current_platform
.
is_rocm
():
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
...
...
@@ -533,7 +532,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO (rob): refactor block quant into separate class.
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
w13_weight
,
w13_weight_scale_inv
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale_inv
,
...
...
@@ -559,9 +558,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype
=
torch
.
float8_e4m3fnuz
\
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
...
...
@@ -608,8 +605,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
current_platform
.
is_rocm
():
if
current_platform
.
is_fp8_fnuz
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
a1c8f379
...
...
@@ -142,8 +142,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
current_platform
.
is_rocm
():
if
current_platform
.
is_fp8_fnuz
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
a1c8f379
...
...
@@ -39,7 +39,7 @@ class QuarkW8A8Fp8(QuarkScheme):
logical_widths
=
layer
.
logical_widths
,
)
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
max_w_scale
,
...
...
@@ -55,7 +55,7 @@ class QuarkW8A8Fp8(QuarkScheme):
elif
self
.
qscheme
==
"per_channel"
:
weight
=
layer
.
weight
if
current_platform
.
is_
rocm
():
if
current_platform
.
is_
fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment