Commit b4859900 authored by Tri Dao

[Gen] Add timing option

parent 0938298e
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
import time
from dataclasses import dataclass, field
from collections import namedtuple

import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

from einops import rearrange
@@ -65,7 +69,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)

def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           fused_ft_kernel=False, cg=False, timing=False):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
@@ -89,17 +94,31 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fu
    next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
    sequences = [next_token]
    inference_params.sequence_len_offset = seqlen_og
    if cg:
        assert fused_ft_kernel
        run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    if timing:
        start = time.time()
    while True:
        position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                  dtype=torch.long, device=input_ids.device)
        if not cg:
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
        else:
            logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
                         inference_params.sequence_len_offset)
        scores.append(logits)
        next_token = sample(logits, top_k=top_k, temperature=temperature)
        sequences.append(next_token)
        inference_params.sequence_len_offset += 1
        if inference_params.sequence_len_offset >= max_length - 1:
            break
    if timing:
        print(f'Decoding time: {time.time() - start}')
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
    # prof.export_chrome_trace("gpt2s_generation.json")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
@@ -116,3 +135,83 @@ class GenerationMixin:
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
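
Note that the `timing` option above reads host-side wall-clock time with `time.time()`; since CUDA kernels are launched asynchronously, a variant that synchronizes the device before reading the clock would capture all queued GPU work. A minimal sketch, assuming a CUDA device and with `decode_fn` as a hypothetical stand-in for the decoding call (not part of this commit):

```python
import time

import torch


def timed(decode_fn, *args, **kwargs):
    """Time a GPU decoding call with explicit synchronization (sketch only)."""
    torch.cuda.synchronize()   # make sure previously queued kernels are done
    start = time.time()
    out = decode_fn(*args, **kwargs)
    torch.cuda.synchronize()   # wait for all decoding kernels to finish
    print(f'Decoding time: {time.time() - start:.3f}s')
    return out
```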
CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])
def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


def seqlen_type_to_seqlen(seqlen_type: int) -> int:
    assert seqlen_type in [0, 1, 2]
    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
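
For intuition, these two helpers bucket the current sequence length into one of three CUDA-graph buckets and map each bucket back to the length used at capture time. An illustrative check (values chosen arbitrarily, not part of the commit):

```python
# Illustrative only: lengths below 32 share one graph, 32..2047 another, 2048+ a third.
for seqlen in [1, 16, 32, 500, 2048, 4096]:
    s_type = seqlen_to_seqlen_type(seqlen)
    print(f'seqlen={seqlen:4d} -> bucket {s_type}, capture length {seqlen_type_to_seqlen(s_type)}')
# seqlen=   1 -> bucket 0, capture length 1
# seqlen=  16 -> bucket 0, capture length 1
# seqlen=  32 -> bucket 1, capture length 32
# seqlen= 500 -> bucket 1, capture length 32
# seqlen=2048 -> bucket 2, capture length 2048
# seqlen=4096 -> bucket 2, capture length 2048
```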
def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
    """Build a cache of cuda graphs for decoding.
    Arguments:
        model: a GPTLMHeadModel
        batch_size: int
        seqlen_og: int. Length of the prompt.
        max_length: int
    TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
    has to own the k_cache and v_cache?
    Here we assume that the model already has inference_params from the prompt processing.
    """
    assert max_length > seqlen_og
    cg_cache: dict[CgKey, CgVal] = {}
    device = next(iter(model.parameters())).device
    sequence_length_offset_og = inference_params.sequence_len_offset
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
                                                     device=device)
    memory_pool = None

    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
        seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
        input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
        position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
        inference_params.lengths_per_sample[:] = seqlen
        inference_params.sequence_len_offset = seqlen
        g = torch.cuda.CUDAGraph()
        # Warmup before capture
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(2):
                logits = model(input_ids, position_ids=position_ids,
                               inference_params=inference_params).logits[:, -1]
        torch.cuda.current_stream().wait_stream(s)
        # Captures the graph
        # To allow capture, automatically sets a side stream as the current stream in the context
        with torch.cuda.graph(g, pool=memory_pool):
            logits = model(input_ids, position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
        if memory_pool is None:
            memory_pool = g.pool()
        cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(
            g, input_ids, position_ids, inference_params.lengths_per_sample, logits
        )

    def run(new_input_ids, new_position_ids, seqlen):
        cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
        inference_params.lengths_per_sample = cg_val.lengths
        inference_params.lengths_per_sample[:] = seqlen
        cg_val.input_ids.copy_(new_input_ids)
        cg_val.position_ids.copy_(new_position_ids)
        cg_val.graph.replay()
        output = cg_val.logits
        return output.clone() if copy_output else output

    inference_params.sequence_len_offset = sequence_length_offset_og
    return run, cg_cache
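
Setting aside the model-specific details, `capture_cg` follows the standard PyTorch CUDA-graph recipe: warm up on a side stream, capture the forward pass into a graph whose input and output tensors have fixed addresses, then decode by copying new data into those tensors and replaying. A stripped-down sketch of that pattern, using a hypothetical toy module that is not part of this commit:

```python
import torch


def capture_and_replay_demo():
    # Hypothetical toy module standing in for one decoding step of the model.
    step = torch.nn.Linear(64, 64).cuda()
    static_in = torch.zeros(8, 64, device='cuda')

    # Warmup on a side stream so the capture doesn't record allocator warm-up work.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_out = step(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single step; static_in / static_out keep fixed addresses.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = step(static_in)

    # Replay: copy fresh data into the captured input, replay, read the output.
    static_in.copy_(torch.randn(8, 64, device='cuda'))
    g.replay()
    return static_out.clone()
```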
@@ -54,8 +54,8 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
    max_length = 30
    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
@@ -73,7 +73,13 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
    out = model.generate(input_ids=input_ids, max_length=max_length,
                         fused_ft_kernel=fused_ft_kernel,
                         return_dict_in_generate=True, output_scores=True, timing=True)
    print(out.sequences)
    if fused_ft_kernel:
        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                fused_ft_kernel=fused_ft_kernel, cg=True,
                                return_dict_in_generate=True, output_scores=True, timing=True)
        print(out_cg.sequences)
    if not rotary:
        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
...