Unverified commit c7ae474a, authored by yigex, committed by GitHub
Browse files

[Feature, Hardware] Enable DeepseekV3 on AMD GPUs (#2601)


Co-authored-by: root <root@banff-cyxtera-s83-5.amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Bruce Xue <yigex@xilinx.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent bdf946bf
...@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd( ...@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
# [TODO] work around shmem limit on MI3xx
if is_hip_ and Lk >= 576:
BLOCK = 16
if Lk == 576: if Lk == 576:
BLOCK_DMODEL = 512 BLOCK_DMODEL = 512
BLOCK_DPE = 64 BLOCK_DPE = 64
......
...@@ -477,9 +477,9 @@ def invoke_fused_moe_kernel( ...@@ -477,9 +477,9 @@ def invoke_fused_moe_kernel(
padded_size = 0 padded_size = 0
if use_fp8_w8a8: if use_fp8_w8a8:
padded_size = padding_size
assert B_scale is not None assert B_scale is not None
if block_shape is None: if block_shape is None:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale) A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else: else:
assert len(block_shape) == 2 assert len(block_shape) == 2
...@@ -614,7 +614,7 @@ def get_default_config( ...@@ -614,7 +614,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 32,
"num_warps": 8, "num_warps": 8,
"num_stages": 4, "num_stages": 2 if is_hip_flag else 4,
} }
if M <= E: if M <= E:
config = { config = {
...@@ -623,7 +623,7 @@ def get_default_config( ...@@ -623,7 +623,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 4, "num_stages": 2 if is_hip_flag else 4,
} }
else: else:
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1] # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
...@@ -633,7 +633,7 @@ def get_default_config( ...@@ -633,7 +633,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1], "BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 32,
"num_warps": 4, "num_warps": 4,
"num_stages": 3, "num_stages": 2 if is_hip_flag else 3,
} }
else: else:
config = { config = {
...@@ -878,7 +878,7 @@ def fused_experts_impl( ...@@ -878,7 +878,7 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
): ):
padded_size = padding_size padded_size = padding_size
if not use_fp8_w8a8: if not use_fp8_w8a8 or block_shape is not None:
padded_size = 0 padded_size = 0
# Check constraints. # Check constraints.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment