Commit f8aea6ea authored by Tri Dao

[GPT] Generalize last_token_only arg to num_last_tokens

parent 7a3bd55f
@@ -621,18 +621,17 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             batch_size, max_seqlen, dtype=dtype, **kwargs
         )
 
-    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
-        last_token_only: whether to return the logit for the last token only,
-            of shape (batch_size, vocab_size)
+        num_last_tokens: if > 0, only return the logits for the last n tokens
         """
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        if last_token_only:
-            hidden_states = hidden_states[:, -1]
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
...
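For readers of the diff: `num_last_tokens=1` reproduces the old `last_token_only=True` behavior except that the sequence dimension is kept, which is why the call sites below gain `.squeeze(dim=1)`. A minimal usage sketch, not part of the commit (`model` and `input_ids` are assumed to already exist):

```python
# Illustrative only: `model` is a GPTLMHeadModel, `input_ids` a LongTensor of shape (batch_size, seqlen).

# Old API: logits for the last position only, shape (batch_size, vocab_size)
# logits = model(input_ids, last_token_only=True).logits

# New API: logits for the last n positions, shape (batch_size, n, vocab_size)
logits = model(input_ids, num_last_tokens=1).logits      # (batch_size, 1, vocab_size)
logits = logits.squeeze(dim=1)                           # (batch_size, vocab_size), as before
logits_4 = model(input_ids, num_last_tokens=4).logits    # (batch_size, 4, vocab_size)

# num_last_tokens=0 (the default) keeps logits for every position:
logits_all = model(input_ids).logits                     # (batch_size, seqlen, vocab_size)
```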
@@ -27,11 +27,19 @@ class InferenceParams:
     lengths_per_sample: Optional[Tensor] = None
 
 
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for none top-k values to -inf."""
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
 def modify_logits_for_top_p_filtering(logits, top_p):
     """Set the logits for none top-p values to -inf."""
-    if top_p <= 0.0:
+    if top_p <= 0.0 or top_p >= 1.0:
         return
     # First sort and calculate cumulative sum of probabilities.
     sorted_logits, sorted_indices = torch.sort(logits, descending=False)
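As context for the new `modify_logits_for_top_k_filtering` and the widened `top_p` guard, a small self-contained sketch with illustrative values (plain torch, not part of the commit):

```python
import torch

# Toy logits for one sequence over a 5-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

# Top-k filtering: everything below the k-th largest logit is set to -inf (in place).
top_k = 2
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(indices_to_remove, float("-Inf"))
print(logits)  # tensor([[2., 1., -inf, -inf, -inf]])
```

The extra `or top_p >= 1.0` check makes the top-p filter an explicit no-op when nucleus filtering is effectively disabled (a top_p of 1.0 keeps the whole distribution), so the sort and cumulative sum are skipped.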
@@ -58,6 +66,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     if top_k > 0:
         top_k = min(top_k, logits.size(-1))  # Safety check
         logits_top, indices = torch.topk(logits, top_k, dim=-1)
-        logits_top /= temperature
+        if temperature != 1.0:
+            logits_top /= temperature
         modify_logits_for_top_p_filtering(logits_top, top_p)
         return indices[
@@ -65,7 +74,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
             torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
         ]
     else:
-        logits_top = logits / temperature
+        # Clone so that when we modify for top_p we don't change the original logits
+        logits_top = logits / temperature if temperature != 1.0 else logits.clone()
         modify_logits_for_top_p_filtering(logits_top, top_p)
         return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
             dim=-1
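The clone in the `else` branch matters because `modify_logits_for_top_p_filtering` edits its argument in place. Previously `logits / temperature` always allocated a fresh tensor; now that the division is skipped when `temperature == 1.0`, a plain alias would let the filter clobber the caller's logits. An illustrative sketch (the `filter_in_place` helper is a stand-in, not the repo's function):

```python
import torch

def filter_in_place(t):
    # Stand-in for modify_logits_for_top_p_filtering: edits `t` in place.
    t.masked_fill_(t < 1.0, float("-Inf"))

logits = torch.tensor([[2.0, 1.0, 0.5]])
logits_top = logits                # aliasing: what skipping both the divide and the clone would do
filter_in_place(logits_top)
print(logits)                      # tensor([[2., 1., -inf]]) -- caller's logits clobbered

logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 1.0
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
filter_in_place(logits_top)
print(logits)                      # tensor([[2.0000, 1.0000, 0.5000]]) -- unchanged
```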
@@ -131,8 +141,8 @@ def decode(
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
-                last_token_only=True,
-            ).logits
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
                 input_ids, position_ids, inference_params.sequence_len_offset
@@ -149,7 +159,9 @@ def decode(
             torch.distributed.barrier()
         torch.cuda.synchronize()
         start = time.time()
-    logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
+    logits = model(
+        input_ids, inference_params=inference_params, num_last_tokens=1
+    ).logits.squeeze(dim=1)
     logits = logits_postprocess_fn(logits)
     scores.append(logits if not cg else logits.clone())
     if teacher_outputs is None or teacher_output_len <= seqlen_og:
@@ -165,9 +177,9 @@ def decode(
                 dtype=torch.long,
                 device=input_ids.device,
             )
-            logits = logits_postprocess_fn(logits_forward_fn(
-                rearrange(next_token, "b -> b 1"), position_ids, inference_params
-            ))
+            logits = logits_postprocess_fn(
+                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
+            )
             scores.append(logits)
         if (
             teacher_outputs is None
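A note on the call reformatted above: `rearrange(next_token, "b -> b 1")` (einops) turns the sampled tokens of shape `(batch_size,)` into the `(batch_size, 1)` input expected for a single decoding step; it is equivalent to `next_token.unsqueeze(1)`. A brief illustration, assuming einops is available:

```python
import torch
from einops import rearrange

next_token = torch.tensor([42, 7])               # shape (2,): one sampled token per sequence
step_input = rearrange(next_token, "b -> b 1")   # shape (2, 1), same as next_token.unsqueeze(1)
```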
@@ -357,7 +369,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
-                last_token_only=True,
+                num_last_tokens=1,
             ).logits
     s.synchronize()
     # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
@@ -374,8 +386,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
             input_ids,
             position_ids=position_ids,
             inference_params=inference_params,
-            last_token_only=True,
-        ).logits
+            num_last_tokens=1,
+        ).logits.squeeze(dim=1)
 
     def run(new_input_ids, new_position_ids, seqlen):
         inference_params.lengths_per_sample[:] = seqlen
...
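The `capture_graph` hunks change only the captured forward call, but for readers unfamiliar with the pattern: the decoding step is recorded into a CUDA graph and then replayed on fixed input/output buffers, which is also why `decode` clones the logits when graphs are enabled (`logits if not cg else logits.clone()`). A generic, simplified sketch of that capture/replay pattern using PyTorch's public API (illustrative, not the repo's `capture_graph` implementation):

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)
s.synchronize()

# Record one forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

def run(new_input):
    static_input.copy_(new_input)   # replay always reads from the captured buffer
    graph.replay()
    return static_output.clone()    # clone: the next replay overwrites static_output
```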
@@ -355,8 +355,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
         config.fused_dropout_add_ln = True
 
     # fused_ft_kernel currently doesn't work with multiple tokens at a time
-    # if not rotary, we load the weight from HF but ignore the position embeddings.
-    # The model would be nonsense but it doesn't matter for the test.
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
     model.eval()
...