"Deepspeed/MoQ/vscode:/vscode.git/clone" did not exist on "31258341b7b81dec027284df9a516e55b79491e9"
Unverified commit dd70437a, authored by Icey, committed by GitHub
Browse files

Remove cuda hard-code in compute_causal_conv1d_metadata (#25555)


Signed-off-by: Icey <1790571317@qq.com>
parent 99b3a504
......@@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
device = query_start_loc_p.device
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
......@@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
device=device)
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
device=device)
else:
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment