Commit a983ea53 authored by 王敏's avatar 王敏
Browse files

[feat]优化高吞吐模式num_sms

parent 3833018c
......@@ -173,6 +173,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
if self.internode:
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2
self.num_sms = 30
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
......@@ -184,6 +185,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
else:
num_rdma_bytes = 0
num_qps_per_rank = 1
self.num_sms = 60
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
......
......@@ -547,41 +547,6 @@ def _fwd_kernel_ep_scatter_2(
SCALE_HIDDEN_SIZE: tl.constexpr,
SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
# start_token_id = tl.program_id(0)
# grid_num = tl.num_programs(0)
# offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
# mask = offset_in < HIDDEN_SIZE
# offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
# mask_s = offset_in_s < SCALE_HIDDEN_SIZE
# for token_id in range(start_token_id, total_token_num, grid_num):
# to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
# to_copy_s = tl.load(
# recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
# )
# for topk_index in tl.range(0, topk_num, 1, num_stages=4):
# expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
# if HAS_EXPERT_MAP:
# expert_id = apply_expert_map(expert_id, expert_map)
# if expert_id >= 0:
# dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
# tl.store(
# output_index + token_id * output_index_stride0 + topk_index,
# dest_token_index,
# )
# output_tensor_ptr = (
# output_tensor + dest_token_index * output_tensor_stride0
# )
# output_tensor_scale_ptr = (
# output_tensor_scale + dest_token_index * output_tensor_scale_stride0
# )
# tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
# tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
start_token_id = tl.program_id(0)
grid_num = tl.num_programs(0)
......@@ -720,43 +685,6 @@ def _fwd_kernel_ep_gather(
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_D: tl.constexpr,
):
# cur_block = tl.program_id(0)
# start_cur_token = tl.program_id(1)
# grid_num = tl.num_programs(1)
# for cur_token in range(start_cur_token, total_token_num, grid_num):
# off_d = tl.arange(0, BLOCK_D)
# accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
# for topk_index in range(0, topk_num):
# expert_id = tl.load(
# recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
# )
# if HAS_EXPERT_MAP:
# expert_id = apply_expert_map(expert_id, expert_map)
# if expert_id >= 0:
# source_token_index = tl.load(
# input_index + cur_token * input_index_stride0 + topk_index
# )
# acc_weight = tl.load(
# recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
# )
# tmp = tl.load(
# input_tensor
# + source_token_index * input_tensor_stride0
# + cur_block * BLOCK_D
# + off_d
# )
# accumulator += tmp.to(tl.float32) * acc_weight
# tl.store(
# output_tensor
# + cur_token * output_tensor_stride0
# + cur_block * BLOCK_D
# + off_d,
# accumulator.to(output_tensor.dtype.element_ty),
# )
cur_block_int32 = tl.program_id(0)
cur_block = cur_block_int32.to(tl.int64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment