Unverified Commit 1a3fa75f authored by JieXin Liang, committed by GitHub


[Fix] use `torch.cat` instead of `torch.concat` to prevent entering the `Autograd` backends. (#4466)
parent 81f431ed
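
Every hunk below makes the same one-line substitution. `torch.concat` is documented as an alias of `torch.cat`; according to the commit title, going through the alias can route the call into the Autograd backend in this code path, so the diff standardizes on `torch.cat`. A minimal, illustrative sketch of the pattern (the tensors here are made up, not taken from the repository):

```python
import torch

# Two chunks to join along the second-to-last dimension, as in the attention hunks below.
xs = [torch.randn(2, 4), torch.randn(2, 4)]

# Before: the alias was used.
# out = torch.concat(xs, dim=-2)

# After: call torch.cat directly, as every call site in this diff now does.
out = torch.cat(xs, dim=-2)
assert out.shape == (4, 4)
```
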
@@ -235,7 +235,7 @@ class MiniMaxText01LightningAttention(nn.Module):
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
- output = torch.concat(output, dim=-2)
+ output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
@@ -403,7 +403,7 @@ class MiniMaxText01LightningAttention(nn.Module):
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
- output = torch.concat(output, dim=-2)
+ output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
@@ -1244,14 +1244,14 @@ class ScheduleBatch:
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
- self.req_pool_indices = torch.concat(
+ self.req_pool_indices = torch.cat(
[self.req_pool_indices, other.req_pool_indices]
)
- self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
+ self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
self.out_cache_loc = None
self.seq_lens_sum += other.seq_lens_sum
if self.output_ids is not None:
- self.output_ids = torch.concat([self.output_ids, other.output_ids])
+ self.output_ids = torch.cat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.token_ids_logprobs.extend(other.token_ids_logprobs)
@@ -303,7 +303,7 @@ class HiRadixCache(RadixCache):
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
- value = torch.concat(value)
+ value = torch.cat(value)
else:
value = torch.tensor([], dtype=torch.int32)
@@ -172,7 +172,7 @@ class TokenToKVPoolAllocator:
return
if self.is_not_in_free_group:
- self.free_slots = torch.concat((self.free_slots, free_index))
+ self.free_slots = torch.cat((self.free_slots, free_index))
else:
self.free_group.append(free_index)
@@ -183,7 +183,7 @@ class TokenToKVPoolAllocator:
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
- self.free(torch.concat(self.free_group))
+ self.free(torch.cat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -739,7 +739,7 @@ class HostKVCache(abc.ABC):
@synchronized
def free(self, indices: torch.Tensor) -> int:
self.mem_state[indices] = MemoryStateInt.IDLE
- self.free_slots = torch.concat([self.free_slots, indices])
+ self.free_slots = torch.cat([self.free_slots, indices])
self.can_use_mem_size += len(indices)
return len(indices)
@@ -272,7 +272,7 @@ class PagedTokenToKVPoolAllocator:
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
- self.free(torch.concat(self.free_group))
+ self.free(torch.cat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -152,7 +152,7 @@ class RadixCache(BasePrefixCache):
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
- value = torch.concat(value)
+ value = torch.cat(value)
else:
value = torch.empty((0,), dtype=torch.int32, device=self.device)
return value, last_node
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
_dfs_helper(child)
_dfs_helper(self.root_node)
- return torch.concat(values)
+ return torch.cat(values)
##### Internal Helper Functions #####
@@ -383,7 +383,7 @@ class ForwardBatch:
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
- self.mrope_positions = torch.concat(
+ self.mrope_positions = torch.cat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
)
@@ -449,7 +449,7 @@ def compute_position_kernel(
def compute_position_torch(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
):
- positions = torch.concat(
+ positions = torch.cat(
[
torch.arange(
prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
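For the `compute_position_torch` hunk above (truncated here by the collapsed diff): a hedged sketch, with a hypothetical helper name, of how a torch-only fallback can build per-token positions by concatenating one `torch.arange` per request with `torch.cat`:

```python
import torch

def compute_positions_sketch(
    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
) -> torch.Tensor:
    # One arange per request, starting at that request's cached prefix length,
    # then flattened into a single positions tensor with torch.cat.
    return torch.cat(
        [
            torch.arange(int(p), int(p) + int(e), device=extend_prefix_lens.device)
            for p, e in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )

# Example: prefix lens [2, 0] and extend lens [3, 2] -> tensor([2, 3, 4, 0, 1])
print(compute_positions_sketch(torch.tensor([2, 0]), torch.tensor([3, 2])))
```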
@@ -1289,7 +1289,7 @@ class MlpProjector(nn.Module):
high_x, low_x = x_or_tuple
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
- x = torch.concat([high_x, low_x], dim=-1)
+ x = torch.cat([high_x, low_x], dim=-1)
else:
x = x_or_tuple
@@ -828,7 +828,7 @@ class MiniCPMVBaseModel(nn.Module):
)
if isinstance(image_embeds, list):
- image_embeds = torch.concat(image_embeds)
+ image_embeds = torch.cat(image_embeds)
return MiniCPMVImageEmbeddingInputs(
image_bounds=image_bounds,
@@ -306,7 +306,7 @@ class SamplingBatchInfo:
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
- setattr(self, item, torch.concat([self_val, other_val]))
+ setattr(self, item, torch.cat([self_val, other_val]))
self.is_all_greedy |= other.is_all_greedy
self.need_min_p_sampling |= other.need_min_p_sampling
......@@ -59,7 +59,7 @@ class EagleDraftInput:
pt = 0
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
- batch.input_ids[pt : pt + extend_len] = torch.concat(
+ batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += extend_len
@@ -148,7 +148,7 @@ def lightning_attention_decode_naive(q, k, v, past_kv, slope):
kv.to(torch.float32),
)
output.append(qkv)
- output = torch.concat(output, dim=-2)
+ output = torch.cat(output, dim=-2)
return output.to(original_dtype), kv
@@ -24,7 +24,7 @@ def naive_lightning_attention_decode(q, k, v, past_kv, slope):
kv.to(torch.float32),
)
output.append(qkv)
- output = torch.concat(output, dim=-2)
+ output = torch.cat(output, dim=-2)
return output.to(original_dtype), kv