Commit b4859900 authored by Tri Dao

[Gen] Add timing option

parent 0938298e
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
import time
from dataclasses import dataclass, field
from collections import namedtuple
import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity
from einops import rearrange
@@ -65,7 +69,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
-def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fused_ft_kernel=True):
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
fused_ft_kernel=False, cg=False, timing=False):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
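For reference, combining the two strategies usually means masking logits before the softmax/multinomial step shown above: keep the top-k entries, then drop the tail of the nucleus. A minimal sketch under that assumption, where `filter_logits` is a hypothetical helper and not the repo's actual `sample`:

# Illustrative top-k then top-p (nucleus) filtering; not the exact implementation in this file.
import torch

def filter_logits(logits, top_k=0, top_p=0.0):
    if top_k > 0:
        # Keep only the k largest logits (top_k == 0 means no limit, i.e. pure sampling).
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p, always keeping the best one.
        sorted_remove = cum_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float('-inf'))
    return logits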
@@ -89,17 +94,31 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fu
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
if cg:
assert fused_ft_kernel
run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
if timing:
start = time.time()
while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
-                                  dtype=torch.long, device=input_ids.device)
-        logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
-                       inference_params=inference_params).logits[:, -1]
dtype=torch.long, device=input_ids.device)
if not cg:
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
else:
logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
scores.append(logits)
        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
sequences.append(next_token)
inference_params.sequence_len_offset += 1
if inference_params.sequence_len_offset >= max_length - 1:
break
if timing:
print(f'Decoding time: {time.time() - start}')
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
# prof.export_chrome_trace("gpt2s_generation.json")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
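One caveat worth noting for the new timing option: CUDA kernels are launched asynchronously, so `time.time()` deltas only reflect GPU work if the device is synchronized around the measured region. A hedged sketch of the same measurement with CUDA events, where `decode_loop()` is a hypothetical stand-in for the while-loop above:

# Sketch: timing GPU work with CUDA events; decode_loop() is a hypothetical stand-in.
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_evt.record()
decode_loop()
end_evt.record()
torch.cuda.synchronize()
print(f'Decoding time: {start_evt.elapsed_time(end_evt) / 1e3:.3f}s')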
@@ -116,3 +135,83 @@ class GenerationMixin:
if not output_scores:
output.scores = None
return output if return_dict_in_generate else output.sequences
CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])
def seqlen_to_seqlen_type(seqlen: int) -> int:
"""Convert sequence length to a seqlen_type.
This is used to determine which cuda graph to use.
Arguments:
seqlen: int
"""
return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
def seqlen_type_to_seqlen(seqlen_type: int) -> int:
assert seqlen_type in [0, 1, 2]
return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
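# Example bucketing (illustrative): seqlen_to_seqlen_type(16) -> 0, seqlen_to_seqlen_type(100) -> 1,
# seqlen_to_seqlen_type(4096) -> 2; seqlen_type_to_seqlen returns the smallest length in each bucket
# (1, 32, 2048). One CUDA graph is captured per bucket, and at decode time the graph for the current
# sequence length's bucket is replayed, so a handful of graphs cover all decoding steps.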
def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
"""Build a cache of cuda graphs for decoding.
Arguments:
model: a GPTLMHeadModel
batch_size: int
seqlen_og: int. Length of the prompt.
max_length: int
TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
has to own the k_cache and v_cache?
Here we assume that the model already has inference_params from the prompt processing.
"""
assert max_length > seqlen_og
cg_cache: dict[CgKey, CgVal] = {}
device = next(iter(model.parameters())).device
sequence_length_offset_og = inference_params.sequence_len_offset
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
device=device)
memory_pool = None
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
inference_params.lengths_per_sample[:] = seqlen
inference_params.sequence_len_offset = seqlen
g = torch.cuda.CUDAGraph()
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g, pool=memory_pool):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
if memory_pool is None:
memory_pool = g.pool()
cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(
g, input_ids, position_ids, inference_params.lengths_per_sample, logits
)
def run(new_input_ids, new_position_ids, seqlen):
cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
inference_params.lengths_per_sample = cg_val.lengths
inference_params.lengths_per_sample[:] = seqlen
cg_val.input_ids.copy_(new_input_ids)
cg_val.position_ids.copy_(new_position_ids)
cg_val.graph.replay()
output = cg_val.logits
return output.clone() if copy_output else output
inference_params.sequence_len_offset = sequence_length_offset_og
return run, cg_cache
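Abstracting away the cache bookkeeping, the capture/replay pattern that `capture_cg` relies on is: warm up on a side stream, capture one forward pass into a `torch.cuda.CUDAGraph` with static input and output tensors, then replay after copying fresh inputs into those static buffers. A minimal self-contained sketch, assuming a CUDA device and a `model` whose input shapes stay fixed:

# Minimal CUDA-graph capture/replay sketch (assumes CUDA and fixed input shapes).
import torch

def make_graphed_step(model, example_input):
    static_input = example_input.clone()
    # Warm up on a side stream so lazy initialization happens before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(2):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)
    # Capture: kernels launched inside this context are recorded rather than run eagerly.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g), torch.no_grad():  # pass pool=... to share memory across graphs
        static_output = model(static_input)

    def step(new_input):
        static_input.copy_(new_input)   # write into the buffer the graph was captured with
        g.replay()                      # re-launch the recorded kernels
        return static_output            # aliases the captured output buffer
    return step

In `capture_cg` the same idea is keyed by sequence-length bucket, the graphs share one memory pool via `g.pool()`, and the k/v cache held in `inference_params` acts as additional static state; returning the static `logits` buffer directly (`copy_output=False`) avoids a copy, but the caller must consume it before the next replay overwrites it.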
@@ -54,8 +54,8 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
max_length = 30
-    # input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
-    # max_length = 512 + 50
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
sequences = []
@@ -73,7 +73,13 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
out = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
-                         return_dict_in_generate=True, output_scores=True)
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
......