Unverified commit 5f6756b0, authored by Morpheus Guo, committed by GitHub
Browse files

[BugFix] fix pre_reorder_triton_kernel default int32 issue (#7814)

parent 98aa836b
...@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel( ...@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel(
): ):
OutDtype = gateup_input_ptr.dtype.element_ty OutDtype = gateup_input_ptr.dtype.element_ty
src_idx = tl.program_id(0) src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size src_ptr = input_ptr + src_idx * hidden_size
...@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel( ...@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel(
else: else:
scale = 1.0 scale = 1.0
dst_idx = tl.load(src2dst_ptr + idx) dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec offset = start_offset + vec
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment