"...git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "3d3a277fbe7ec2fa9bd5642b0bd375bfccb05d18"
Unverified commit b172bad8, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix dtype for KV inference cache (#319)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6bd35bf9
......@@ -1108,14 +1108,14 @@ class MultiHeadAttention(torch.nn.Module):
def _allocate_memory(
    self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
) -> torch.Tensor:
    """Allocate an empty KV-cache buffer for incremental inference.

    The buffer is shaped (max_seq_len, batch, heads_per_partition, head_dim)
    and lives on the current CUDA device. The dtype is supplied by the caller
    (typically ``hidden_states.dtype``) rather than taken from
    ``self.params_dtype``, so the cache matches the activation dtype even when
    it differs from the parameter dtype (e.g. under autocast).

    Args:
        inference_max_sequence_len: maximum sequence length the cache must hold.
        batch_size: maximum batch size the cache must hold.
        dtype: element dtype for the cache tensor.

    Returns:
        An uninitialized ``torch.Tensor`` on ``torch.cuda.current_device()``.
    """
    return torch.empty(
        inference_max_sequence_len,
        batch_size,
        self.num_attention_heads_per_partition,
        self.hidden_size_per_attention_head,
        dtype=dtype,
        device=torch.cuda.current_device(),
    )
......@@ -1154,10 +1154,10 @@ class MultiHeadAttention(torch.nn.Module):
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment