Commit 4661cd18 authored by zhangqha
Browse files

Merge branch 'v0.15.1-dev-fth' into 'v0.15.1-dev'

修复 channel-int8 的 block_shape 读取 bug (Fix the block_shape read bug for channel-wise int8 quantization)

See merge request dcutoolkit/deeplearing/vllm!462
parents e962f483 3af22744
...@@ -71,7 +71,8 @@ __device__ inline bool cmp_eq(const T& a, const T& b) { ...@@ -71,7 +71,8 @@ __device__ inline bool cmp_eq(const T& a, const T& b) {
// Fixed constants common to both dynamic and static template versions: // Fixed constants common to both dynamic and static template versions:
static constexpr int SIZE_WARP = 32; static constexpr int SIZE_WARP = 32;
static constexpr int WARPS_PER_CTA = 6; static constexpr int WARPS_PER_CTA = 6;
static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group // static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
static constexpr int MAX_VPT = 128; // Extend MAX_VPT from 32 to 128 to accommodate large-scale MoE models (e.g., GLM-4V-quantized model).
// Create an alias for Array using AlignedArray // Create an alias for Array using AlignedArray
template <typename T, int N> template <typename T, int N>
......
...@@ -246,8 +246,8 @@ class FusedMoEQuantConfig: ...@@ -246,8 +246,8 @@ class FusedMoEQuantConfig:
@property @property
def block_shape(self) -> list[int] | None: def block_shape(self) -> list[int] | None:
if self.use_int8_w8a8: # if self.use_int8_w8a8:
return [256, 256] # return [256, 256]
if ( if (
self._a1.shape is not None self._a1.shape is not None
...@@ -572,7 +572,7 @@ def int8_w8a8_moe_quant_config( ...@@ -572,7 +572,7 @@ def int8_w8a8_moe_quant_config(
a2_scale=a2_scale, a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=[256, 256], block_shape=None,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment