Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a59531f8
Commit
a59531f8
authored
Jan 27, 2026
by
zhuwenwen
Browse files
Merge branch 'v0.11.0-dev-Q' into 'v0.11.0-dev'
V0.11.0 dev q See merge request dcutoolkit/deeplearing/vllm!392
parents
0289bb5b
1fb40bd3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
42 additions
and
17 deletions
+42
-17
CMakeLists.txt
CMakeLists.txt
+1
-1
csrc/ops.h
csrc/ops.h
+2
-2
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+8
-3
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+21
-0
vllm/attention/layer.py
vllm/attention/layer.py
+3
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+7
-8
No files found.
CMakeLists.txt
View file @
a59531f8
...
@@ -266,7 +266,7 @@ set(VLLM_EXT_SRC
...
@@ -266,7 +266,7 @@ set(VLLM_EXT_SRC
"csrc/cuda_view.cu"
"csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu"
# "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
#
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu"
# "csrc/quantization/activation_kernels.cu"
...
...
csrc/ops.h
View file @
a59531f8
...
@@ -318,8 +318,8 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
...
@@ -318,8 +318,8 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
torch::Tensor const& scale);
torch
::
Tensor
const
&
scale
);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
// torch::Tensor& scale);
...
...
csrc/quantization/fp8/common.cuh
View file @
a59531f8
...
@@ -47,15 +47,20 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
...
@@ -47,15 +47,20 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x
=
val
/
scale
;
x
=
val
/
scale
;
}
}
float
r
=
//
float r =
fmaxf
(
-
quant_type_max_v
<
fp8_type
>
,
fminf
(
x
,
quant_type_max_v
<
fp8_type
>
));
//
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
#ifndef USE_ROCM
// Use hardware cvt instruction for fp8 on nvidia
// Use hardware cvt instruction for fp8 on nvidia
// Currently only support fp8_type = c10::Float8_e4m3fn
// Currently only support fp8_type = c10::Float8_e4m3fn
return
fp8
::
vec_conversion
<
fp8_type
,
float
>
(
r
);
return
fp8
::
vec_conversion
<
fp8_type
,
float
>
(
r
);
#else
#else
fp8_type
*
test
;
uint8_t
test_uint8
=
fp8
::
float_to_fp8_e4m3
(
x
);
test
=
(
fp8_type
*
)(
&
test_uint8
);
return
*
test
;
// Use hardware cvt instruction for fp8 on rocm
// Use hardware cvt instruction for fp8 on rocm
return
fp8
::
cvt_c10
<
fp8_type
>
(
r
);
//
return fp8::cvt_c10<fp8_type>(r);
#endif
#endif
}
}
...
...
csrc/torch_bindings.cpp
View file @
a59531f8
...
@@ -601,6 +601,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -601,6 +601,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// "()");
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> "
// "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"
);
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> "
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def(
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
...
...
vllm/attention/layer.py
View file @
a59531f8
...
@@ -258,7 +258,7 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -258,7 +258,7 @@ class Attention(nn.Module, AttentionLayerBase):
# @TODO
# @TODO
if
envs
.
VLLM_USE_QUERY_QUANT
:
if
envs
.
VLLM_USE_QUERY_QUANT
:
if
self
.
kv_cache_dtype
.
startswith
(
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
self
.
attn_backend
.
supports_quant_query_input
:
"fp8"
)
and
self
.
attn_backend
.
supports_quant_query_input
:
self
.
query_quant
=
QuantFP8
(
static
=
True
,
self
.
query_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
group_shape
=
GroupShape
.
PER_TENSOR
)
...
@@ -303,11 +303,11 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -303,11 +303,11 @@ class Attention(nn.Module, AttentionLayerBase):
if
output_shape
is
not
None
else
query
.
shape
)
if
output_shape
is
not
None
else
query
.
shape
)
if
envs
.
VLLM_USE_OPT_ZEROS
:
if
envs
.
VLLM_USE_OPT_ZEROS
:
output
=
torch
.
empty
(
output_shape
,
output
=
torch
.
empty
(
output_shape
,
dtype
=
query
.
dtype
,
dtype
=
output_
dtype
,
device
=
query
.
device
)
device
=
query
.
device
)
else
:
else
:
output
=
torch
.
zeros
(
output_shape
,
output
=
torch
.
zeros
(
output_shape
,
dtype
=
query
.
dtype
,
dtype
=
output_
dtype
,
device
=
query
.
device
)
device
=
query
.
device
)
hidden_size
=
output_shape
[
-
1
]
hidden_size
=
output_shape
[
-
1
]
# We skip reshaping query, key and value tensors for the MLA
# We skip reshaping query, key and value tensors for the MLA
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
a59531f8
...
@@ -646,7 +646,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -646,7 +646,7 @@ class FlashAttentionImpl(AttentionImpl):
scheduler_metadata
=
scheduler_metadata
,
scheduler_metadata
=
scheduler_metadata
,
# fa_version=self.vllm_flash_attn_version,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# q_descale=layer._q_scale.expand(descale_shape),
q_descale
=
Non
e
,
q_descale
=
layer
.
_q_scal
e
,
k_descale
=
layer
.
_k_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=attn_metadata.max_num_splits,
# num_splits=attn_metadata.max_num_splits,
...
@@ -674,7 +674,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -674,7 +674,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
#
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
q_descale
=
layer
.
_q_scale
,
...
@@ -699,11 +699,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -699,11 +699,10 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
2
,
#self.vllm_flash_attn_version,
#
fa_version=2, #self.vllm_flash_attn_version,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
# q_descale=layer._q_scale,
q_descale
=
layer
.
_q_scale
,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
v_descale
=
layer
.
_v_scale
,
)
)
...
@@ -783,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -783,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape),
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale
=
Non
e
,
q_descale
=
layer
.
_q_scal
e
,
k_descale
=
layer
.
_k_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
v_descale
=
layer
.
_v_scale
,
is_prefix_cache
=
False
,
is_prefix_cache
=
False
,
...
@@ -914,7 +913,7 @@ def cascade_attention(
...
@@ -914,7 +913,7 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
#
fa_version=fa_version,
q_descale
=
q_descale
.
expand
(
descale_shape
)
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
k_descale
=
k_descale
.
expand
(
descale_shape
)
...
@@ -967,7 +966,7 @@ def cascade_attention(
...
@@ -967,7 +966,7 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
#
fa_version=fa_version,
q_descale
=
q_descale
.
expand
(
descale_shape
)
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
k_descale
=
k_descale
.
expand
(
descale_shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment