Commit 58bbb720 authored by zhuwenwen
Browse files

[fix]fix tests of fused_moe

parent 4a946680
...@@ -17,8 +17,10 @@ from vllm.utils import get_max_shared_memory_bytes ...@@ -17,8 +17,10 @@ from vllm.utils import get_max_shared_memory_bytes
if not current_platform.is_rocm(): if not current_platform.is_rocm():
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.backends.xformers import _make_alibi_bias
if current_platform.is_rocm():
from flash_attn import vllm_flash_attn_with_kvcache
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability. # This will change depending on the compute capability.
...@@ -223,7 +225,6 @@ def test_paged_attention( ...@@ -223,7 +225,6 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0] cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"): elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm": if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM PARTITION_SIZE = PARTITION_SIZE_ROCM
...@@ -268,7 +269,7 @@ def test_paged_attention( ...@@ -268,7 +269,7 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0] cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
else: else:
ops.paged_attention_rocm( ops.paged_attention_rocm(
output, output,
......
...@@ -226,10 +226,10 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -226,10 +226,10 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
rtol=1e-3) rtol=1e-3)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) # @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13]) # @pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) # @pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"seq_len_chunk_size_cases", "seq_len_chunk_size_cases",
[ [
...@@ -255,56 +255,56 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -255,56 +255,56 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
(64, 256, 2, [(5, 30), (1, 2), (1, 2), (64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences (1, 2)]), # irregular sizes with small sequences
]) ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype): # itype):
# this test with multiple examples in a continuous batch # # this test with multiple examples in a continuous batch
# (i.e. chunked prefill) # # (i.e. chunked prefill)
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases # seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# hold state during the cutting process so we know if an # # hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle # # example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample # last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted # exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None # states = None
for Y_min, cu_seqlens, seq_idx, ( # for Y_min, cu_seqlens, seq_idx, (
A, dt, X, B, C) in generate_continuous_batched_examples( # A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads, # cases, num_examples, seqlen, last_taken, exhausted, n_heads,
d_head, itype): # d_head, itype):
chunk_indices, chunk_offsets = \ # chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets( # _query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1]) # cu_seqlens, chunk_size, cu_seqlens[-1])
Y, new_states = mamba_chunk_scan_combined( # Y, new_states = mamba_chunk_scan_combined(
X, # X,
dt, # dt,
A, # A,
B, # B,
C, # C,
chunk_size, # chunk_size,
D=None, # D=None,
cu_seqlens=cu_seqlens, # cu_seqlens=cu_seqlens,
seq_idx=seq_idx, # seq_idx=seq_idx,
chunk_indices=chunk_indices, # chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets, # chunk_offsets=chunk_offsets,
return_varlen_states=True, # return_varlen_states=True,
initial_states=states, # initial_states=states,
) # )
# just test the last in sequence # # just test the last in sequence
for i in range(num_examples): # for i in range(num_examples):
# just test one dim and dstate # # just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] # Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0] # Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) # torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
# update states # # update states
states = new_states # states = new_states
for i, clear in exhausted.items(): # for i, clear in exhausted.items():
if clear: # if clear:
states[i].fill_(0.) # states[i].fill_(0.)
exhausted[i] = False # exhausted[i] = False
...@@ -174,6 +174,7 @@ def test_fused_moe( ...@@ -174,6 +174,7 @@ def test_fused_moe(
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
use_int4_w4a8=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=None) block_shape=None)
...@@ -332,6 +333,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -332,6 +333,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize=False, renormalize=False,
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, use_int8_w8a16=weight_bits == 8,
use_int4_w4a8=weight_bits == 4,
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
w1_scale=w1_scales, w1_scale=w1_scales,
...@@ -394,12 +396,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -394,12 +396,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
).cuda() ).cuda()
# Load the weights # Load the weights
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data if not current_platform.is_rocm():
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
else:
vllm_moe.gate.weight.data[:] = (hf_moe.gate.weight.data).T
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data, weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data) hf_moe.experts[i].w3.weight.data)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) if not current_platform.is_rocm():
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
else:
vllm_moe.experts.w13_weight[i][:] = (torch.cat(weights, dim=0)).T
vllm_moe.experts.w2_weight[i][:] = (hf_moe.experts[i].w2.weight.data).T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim] # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn( hf_inputs = torch.randn(
......
...@@ -50,6 +50,7 @@ def get_config_quant_dtype( ...@@ -50,6 +50,7 @@ def get_config_quant_dtype(
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8: bool,
) -> Optional[torch.dtype]: ) -> Optional[torch.dtype]:
if use_fp8_w8a8: if use_fp8_w8a8:
return torch.float8_e4m3fn return torch.float8_e4m3fn
...@@ -126,6 +127,7 @@ class FusedMoEQuantConfig: ...@@ -126,6 +127,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
per_out_ch_quant: bool = False, per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
...@@ -136,6 +138,7 @@ class FusedMoEQuantConfig: ...@@ -136,6 +138,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16, use_int4_w4a16,
use_int4_w4a8,
] ]
]) <= 1, "Quantization flags are mutually exclusive." ]) <= 1, "Quantization flags are mutually exclusive."
...@@ -144,6 +147,7 @@ class FusedMoEQuantConfig: ...@@ -144,6 +147,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
) )
return FusedMoEQuantConfig( return FusedMoEQuantConfig(
quant_dtype, quant_dtype,
......
...@@ -1603,7 +1603,8 @@ def fused_experts_impl( ...@@ -1603,7 +1603,8 @@ def fused_experts_impl(
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16) use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8)
get_config_func = functools.partial( get_config_func = functools.partial(
try_get_optimal_moe_config, try_get_optimal_moe_config,
...@@ -1877,7 +1878,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1877,7 +1878,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False, use_int4_w4a8: bool = False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
): ):
...@@ -1896,7 +1897,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1896,7 +1897,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16 self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16 self.use_int8_w8a16 = use_int8_w8a16
self.use_int4_w4a8= use_int4_w4a8 self.use_int4_w4a8 = use_int4_w4a8
@property @property
def activation_formats( def activation_formats(
...@@ -2016,6 +2017,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2016,6 +2017,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w1_scale, w1_scale,
w1_zp, w1_zp,
None, None,
topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
...@@ -2027,7 +2029,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2027,7 +2029,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8= self.use_int4_w4a8, use_int4_w4a8=self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
...@@ -2047,6 +2049,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2047,6 +2049,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2_scale, w2_scale,
w2_zp, w2_zp,
None, None,
topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
...@@ -2068,7 +2071,7 @@ def modular_triton_fused_moe( ...@@ -2068,7 +2071,7 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8:bool, use_int4_w4a8: bool,
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
...@@ -2079,7 +2082,7 @@ def modular_triton_fused_moe( ...@@ -2079,7 +2082,7 @@ def modular_triton_fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8= use_int4_w4a8, use_int4_w4a8=use_int4_w4a8,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
), ),
......
...@@ -36,9 +36,10 @@ class ActivationMethod(IntEnum): ...@@ -36,9 +36,10 @@ class ActivationMethod(IntEnum):
@cache @cache
def is_rocm_aiter_moe_enabled() -> bool: def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \ return False
and envs.VLLM_ROCM_USE_AITER_MOE \ # return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER # and envs.VLLM_ROCM_USE_AITER_MOE \
# and envs.VLLM_ROCM_USE_AITER
def rocm_aiter_asm_moe_tkw1_impl( def rocm_aiter_asm_moe_tkw1_impl(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment