Commit ba2fe7f3 authored by Tri Dao

[Gen] Move allocate_inference_cache to within the model

parent 3da42d24
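This commit moves key-value cache allocation out of the standalone generation helper and into the model itself: `GPTLMHeadModel.allocate_inference_cache` delegates to `GPTModel`, which collects one cache per `Block`, and each `Block` asks its `mixer` (`MHA`) to allocate the actual tensors. The sketch below illustrates that delegation chain with stub classes; the class bodies and shapes are placeholders for illustration, not the real flash-attn modules in the diff below.

```python
import torch
import torch.nn as nn

# Stub mixer standing in for MHA: allocates one KV-cache tensor per layer.
class StubMixer(nn.Module):
    def __init__(self, num_heads=2, head_dim=8):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        dtype = dtype or torch.float16
        return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim, dtype=dtype)

# The block delegates to its mixer, the backbone collects a dict keyed by layer index,
# and the LM-head wrapper delegates to the backbone, mirroring the diff below.
class StubBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.mixer = StubMixer()

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

class StubBackbone(nn.Module):
    def __init__(self, n_layer=3):
        super().__init__()
        self.layers = nn.ModuleList(StubBlock() for _ in range(n_layer))

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
                for i, layer in enumerate(self.layers)}

class StubLMHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = StubBackbone()

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

cache = StubLMHeadModel().allocate_inference_cache(batch_size=2, max_seqlen=16)
print({i: tuple(t.shape) for i, t in cache.items()})  # {0: (2, 16, 2, 2, 8), 1: ..., 2: ...}
```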
@@ -335,6 +335,10 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+                for i, layer in enumerate(self.layers)}
+
     def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
@@ -426,6 +430,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
+                                                          **kwargs)
+
     def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
@@ -105,6 +105,9 @@ class Block(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                 mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
@@ -416,6 +416,22 @@ class MHA(nn.Module):
                                              attention_dropout=dropout)
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+        dtype = self.out_proj.weight.dtype if dtype is None else dtype
+        device = self.out_proj.weight.device
+        if not fused_ft_kernel:
+            return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim,
+                               dtype=dtype, device=device)
+        else:
+            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if dtype == torch.float32 else 8
+            assert self.head_dim % packsize == 0
+            k_cache = torch.empty(batch_size, self.num_heads, self.head_dim // packsize, max_seqlen,
+                                  packsize, dtype=dtype, device=device)
+            v_cache = torch.empty(batch_size, self.num_heads, max_seqlen, self.head_dim,
+                                  dtype=dtype, device=device)
+            return k_cache, v_cache
+
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
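On the fused-kernel path (`fused_ft_kernel=True`), the key cache is stored in a packed layout: `head_dim` is split into groups of `packsize` elements (8 for fp16/bf16, 4 for fp32), while the value cache keeps the conventional `(batch, heads, seqlen, head_dim)` layout. A minimal shape check, using illustrative sizes rather than any particular model config:

```python
import torch

batch_size, max_seqlen, num_heads, head_dim = 2, 128, 12, 64
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8   # 8 elements per pack for fp16/bf16
assert head_dim % packsize == 0

# Packed K cache: head_dim is split into head_dim // packsize groups of `packsize` elements.
k_cache = torch.empty(batch_size, num_heads, head_dim // packsize, max_seqlen, packsize,
                      dtype=dtype)
# V cache keeps the conventional layout.
v_cache = torch.empty(batch_size, num_heads, max_seqlen, head_dim, dtype=dtype)

print(k_cache.shape)  # torch.Size([2, 12, 8, 128, 8])
print(v_cache.shape)  # torch.Size([2, 12, 128, 64])
# Both caches hold the same number of elements; only the K layout is permuted and packed.
assert k_cache.numel() == v_cache.numel()
```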
@@ -167,8 +167,8 @@ class GenerationMixin:
         return output if return_dict_in_generate else output.sequences
 
 
-def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
-                      device, dtype=torch.float16):
+def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
+                             device, dtype=torch.float16):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
     assert headdim % packsize == 0
@@ -226,14 +226,17 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         headdim = getattr(model.config, 'head_dim',
                           model.config.hidden_size // model.config.num_attention_heads)
-        kv_cache = allocate_kv_cache(
-            batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
-            model.config.num_hidden_layers, device, dtype
-        )
+        if hasattr(model, 'allocate_inference_cache'):
+            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+        else:
+            inf_cache = allocate_inference_cache(
+                batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
+                model.config.num_hidden_layers, device, dtype
+            )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
             max_sequence_len=max_seqlen, max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
+            sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
             lengths_per_sample=lengths_per_sample
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
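With this change, `update_graph_cache` prefers a model-provided `allocate_inference_cache` and only falls back to the free function when the model does not define one, so models without the new method keep working. The sketch below illustrates that dispatch rule with hypothetical stub classes and a placeholder fallback allocator; it is not the actual `update_graph_cache` code.

```python
import torch

def fallback_allocate(batch_size, max_seqlen, nheads, headdim, num_layers, dtype=torch.float16):
    # Hypothetical stand-in for the module-level allocate_inference_cache helper.
    return {i: torch.empty(batch_size, max_seqlen, 2, nheads, headdim, dtype=dtype)
            for i in range(num_layers)}

class ModelWithCache:
    # Models like GPTLMHeadModel now expose this method themselves.
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = dtype or torch.float16
        return {0: torch.empty(batch_size, max_seqlen, 2, 4, 16, dtype=dtype)}

class LegacyModel:
    # No allocate_inference_cache method: takes the fallback path.
    pass

def build_cache(model, batch_size, max_seqlen, dtype=torch.float16):
    # Same dispatch rule as the diff above: prefer the model's own allocator.
    if hasattr(model, 'allocate_inference_cache'):
        return model.allocate_inference_cache(batch_size, max_seqlen, dtype)
    return fallback_allocate(batch_size, max_seqlen, nheads=4, headdim=16, num_layers=1, dtype=dtype)

print(type(build_cache(ModelWithCache(), 2, 32)))  # dict built by the model's own method
print(type(build_cache(LegacyModel(), 2, 32)))     # dict built by the fallback helper
```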