Commit a59531f8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-Q' into 'v0.11.0-dev'

V0.11.0 dev q

See merge request dcutoolkit/deeplearing/vllm!392
parents 0289bb5b 1fb40bd3
......@@ -266,7 +266,7 @@ set(VLLM_EXT_SRC
"csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu"
......
......@@ -318,8 +318,8 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
......
......@@ -47,15 +47,20 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x = val / scale;
}
float r =
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
// float r =
// fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
// Use hardware cvt instruction for fp8 on nvidia
// Currently only support fp8_type = c10::Float8_e4m3fn
return fp8::vec_conversion<fp8_type, float>(r);
#else
fp8_type *test;
uint8_t test_uint8 = fp8::float_to_fp8_e4m3(x);
test = (fp8_type*)(&test_uint8);
return *test;
// Use hardware cvt instruction for fp8 on rocm
return fp8::cvt_c10<fp8_type>(r);
// return fp8::cvt_c10<fp8_type>(r);
#endif
}
......
......@@ -601,6 +601,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> "
// "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> "
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
......
......@@ -258,7 +258,7 @@ class Attention(nn.Module, AttentionLayerBase):
# @TODO
if envs.VLLM_USE_QUERY_QUANT:
if self.kv_cache_dtype.startswith(
"fp8") and self.attn_backend.supports_quant_query_input:
"fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
......@@ -303,11 +303,11 @@ class Attention(nn.Module, AttentionLayerBase):
if output_shape is not None else query.shape)
if envs.VLLM_USE_OPT_ZEROS:
output = torch.empty(output_shape,
dtype=query.dtype,
dtype=output_dtype,
device=query.device)
else:
output = torch.zeros(output_shape,
dtype=query.dtype,
dtype=output_dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
......
......@@ -646,7 +646,7 @@ class FlashAttentionImpl(AttentionImpl):
scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
q_descale=None,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits,
......@@ -674,7 +674,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.vllm_flash_attn_version,
# fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
......@@ -699,11 +699,10 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=2, #self.vllm_flash_attn_version,
# fa_version=2, #self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale,
q_descale=None,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
......@@ -783,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
is_prefix_cache=False,
......@@ -914,7 +913,7 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
fa_version=fa_version,
# fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
......@@ -967,7 +966,7 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
fa_version=fa_version,
# fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment