Unverified commit 12d68183, authored by fzyzcjy, committed by GitHub

Tiny fix ep_gather behavior different in CI (#11130)

parent b65db028
@@ -1104,10 +1104,10 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
+    BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024  # block size of quantization
     assert hidden_size % BLOCK_D == 0
     grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
     _fwd_kernel_ep_gather[grid](
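The change replaces an environment-based block-size choice (`is_in_ci()`) with one derived purely from the tensor shape, so the kernel behaves identically in CI and production. A minimal sketch of the new selection logic, extracted as a standalone helper (the function name `pick_block_d` is hypothetical, not part of the patch):

```python
def pick_block_d(hidden_size: int) -> int:
    """Choose the Triton block size along the hidden dimension.

    Mirrors the patched logic: prefer a 1024-wide block when the hidden
    size divides evenly by 1024, otherwise fall back to 128, instead of
    switching on whether the code runs in CI.
    """
    block_d = 128 if hidden_size % 1024 != 0 else 1024
    # The kernel asserts this: the hidden dimension must split evenly
    # into blocks, or the grid launch would leave a partial block.
    assert hidden_size % block_d == 0
    return block_d
```

Because the choice now depends only on `hidden_size`, a shape like 7168 (7 × 1024) gets the large block in every environment, while a shape like 1152 (9 × 128) consistently gets the small one.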