Commit 3da42d24 authored by Tri Dao

[GPT] Add option to only return the logit for the last token

parent 311d6606
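For readers skimming the diff: the new keyword argument lets callers skip slicing the sequence dimension themselves. A minimal usage sketch, assuming a constructed GPTLMHeadModel `model` and an `input_ids` tensor already exist (only the `last_token_only` keyword and the `.logits` field come from this commit; everything else is illustrative):

import torch

# Sketch only: `model` is assumed to be a GPTLMHeadModel instance and
# `input_ids` an integer tensor of shape (batch_size, seqlen).
out_full = model(input_ids)                        # .logits: (batch_size, seqlen, vocab_size)
out_last = model(input_ids, last_token_only=True)  # .logits: (batch_size, vocab_size)

# The last-token logits should match the last time step of the full forward pass.
torch.testing.assert_close(out_full.logits[:, -1], out_last.logits)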
@@ -426,20 +426,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        last_token_only: whether to return the logit for the last token only,
+            of shape (batch_size, vocab_size)
         """
         hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                          inference_params=inference_params)
+        if last_token_only:
+            hidden_states = hidden_states[:, -1]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
-            lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
+            lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
...
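The only subtle change in the hunk above is the einops pattern in the tensor-parallel branch: '(n b) s d -> b s (n d)' becomes '(n b) ... d -> b ... (n d)', so the all-gathered logit shards can be reassembled whether or not the sequence dimension has already been dropped. A standalone sketch of that pattern (shapes and variable names are illustrative, not from the repo):

import torch
from einops import rearrange

n, b, s, d = 2, 3, 5, 4          # tensor-parallel ranks, batch, seqlen, per-rank vocab shard
full = torch.randn(n * b, s, d)  # gathered logits with a sequence dimension
last = torch.randn(n * b, d)     # gathered logits when last_token_only=True

# The ellipsis absorbs the optional sequence dimension, so one pattern covers both cases.
print(rearrange(full, '(n b) ... d -> b ... (n d)', b=b).shape)  # torch.Size([3, 5, 8])
print(rearrange(last, '(n b) ... d -> b ... (n d)', b=b).shape)  # torch.Size([3, 8])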
@@ -112,7 +112,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -127,7 +127,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                       dtype=torch.long, device=input_ids.device)
             if not cg:
                 logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
-                               inference_params=inference_params).logits[:, -1]
+                               inference_params=inference_params, last_token_only=True).logits
             else:
                 logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                    inference_params.sequence_len_offset)
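The two decode hunks above only change how the per-step logits are obtained; the sampling that consumes them is untouched. For context, a generic sketch of turning (batch_size, vocab_size) logits into next-token ids with temperature and top-k, in the spirit of the `top_k`/`temperature` arguments visible in the hunk headers (the helper name and exact filtering logic are illustrative, not the repo's implementation):

import torch

def sample_next_token(logits, top_k=1, temperature=1.0):
    """Sample token ids of shape (batch_size,) from logits of shape (batch_size, vocab_size)."""
    if temperature != 1.0:
        logits = logits / temperature
    if top_k > 0:
        # Keep only the top-k logits per row; push everything else to -inf before softmax.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float('-inf'))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)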
@@ -269,8 +269,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
         for _ in range(n_warmups):
-            logits = model(input_ids, position_ids=position_ids,
-                           inference_params=inference_params).logits[:, -1]
+            logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
+                           last_token_only=True).logits
         s.synchronize()
         # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
         # which requires that graph launch and non-captured launch to not overlap (I think,
@@ -282,8 +282,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     # To allow capture, automatically sets a side stream as the current stream in the context
     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph, pool=mempool):
-        logits = model(input_ids, position_ids=position_ids,
-                       inference_params=inference_params).logits[:, -1]
+        logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
+                       last_token_only=True).logits
 
     def run(new_input_ids, new_position_ids, seqlen):
         inference_params.lengths_per_sample[:] = seqlen
...
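The capture_graph hunks follow PyTorch's standard CUDA-graph recipe: warm up on a side stream, capture one forward pass into static buffers, then replay after copying fresh inputs into those buffers. A self-contained sketch of that pattern under assumptions (a toy nn.Linear stands in for the GPT model; the real code additionally threads through inference_params, a shared memory pool, and the new last_token_only flag):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device='cuda')

# Warm up on a side stream so lazy initialization happens outside the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass; tensors used here become the graph's static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the static input, replay the graph, read the static output.
static_input.copy_(torch.randn(8, 16, device='cuda'))
graph.replay()
print(static_output.shape)  # torch.Size([8, 16])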