gaoqiong / flash-attention · Commits

Commit ae856f3a, authored Mar 28, 2024 by Woosuk Kwon
Commit message: Remove unnecessary files
Parent: 6ac8e63a
Showing 5 changed files with 0 additions and 1226 deletions (+0 -1226)

    vllm_flash_attn/utils/__init__.py     +0   -0
    vllm_flash_attn/utils/benchmark.py    +0   -268
    vllm_flash_attn/utils/distributed.py  +0   -144
    vllm_flash_attn/utils/generation.py   +0   -735
    vllm_flash_attn/utils/pretrained.py   +0   -79
vllm_flash_attn/utils/__init__.py  (deleted, 100644 → 0; the file was empty)
vllm_flash_attn/utils/benchmark.py  (deleted, 100644 → 0)
# Copyright (c) 2023, Tri Dao.

""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
            g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
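For reference, the helpers above are typically driven as in the following minimal sketch. This is not part of the deleted file; it assumes a CUDA device is available and that `vllm_flash_attn.utils.benchmark` is imported from a checkout that still contains it (this commit removes it).

# Hypothetical usage sketch: time a matmul forward pass and the
# forward/backward/combined passes with the helpers above.
import torch

from vllm_flash_attn.utils.benchmark import benchmark_all, benchmark_forward

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16, requires_grad=True)

# Forward pass only: returns the benchmark.Timer and its Measurement.
timer, measurement = benchmark_forward(torch.matmul, a, b, repeats=10, desc="matmul")
print(measurement.mean)  # mean time per call, in seconds

# Forward, backward, and combined passes in one call (three (Timer, Measurement) pairs).
results = benchmark_all(torch.matmul, a, b, repeats=10, desc="matmul", verbose=False)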
vllm_flash_attn/utils/distributed.py  (deleted, 100644 → 0)
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
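To make the uneven-split logic of get_dim_for_local_rank concrete, here is a small CPU-only sketch (not part of the deleted file; the numbers are chosen arbitrarily, and the import assumes a checkout that still contains the module). Earlier ranks absorb the remainder, and the per-rank dims sum back to the full dimension.

# Hypothetical sketch: split a hidden dimension of 4096, in multiples of 128,
# across 3 ranks. multiple = 32, so ranks get 11, 11, and 10 chunks of 128.
from vllm_flash_attn.utils.distributed import get_dim_for_local_rank

dims = [
    get_dim_for_local_rank(4096, world_size=3, local_rank=r, multiple_of=128) for r in range(3)
]
print(dims)              # [1408, 1408, 1280]
assert sum(dims) == 4096  # the split always covers the full dimension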
vllm_flash_attn/utils/generation.py  (deleted, 100644 → 0)
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for none top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)
    gather = lambda probs, tokens: rearrange(
        probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
    )
    # (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,)
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    tokens[:, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1


@torch.inference_mode()
def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now.

    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # if inference_params.lengths_per_sample is None:
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            if True:
                cache_seqlens = torch.full(
                    (input_ids.shape[0],),
                    inference_params.seqlen_offset,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            else:
                cache_seqlens = inference_params.lengths_per_sample
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
            # This might not be compatible the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
                (num_tokens - 1) tokens. The logits of the last token isn't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for i in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample from draft model, which produces @n_spec_tokens, and @model
        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())

        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(
                cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
            ).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    while True:
        # seqlen_offset is total length generated - 1
        if inference_params.seqlen_offset >= max_length - 1:
            break
        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
            scores.append(scores_new)
            break
        # Sample from draft model
        n_spec_tokens = min(
            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        # If the main model accepts all the draft tokens, plus it samples one new token,
        # then at the next iteration the draft model need to evaluate the logits of the last draft
        # token and the logits of the newly sampled token. So here we pass in the last 2 tokens
        # of sequences[-1].
        # This exception is when the main model rejects all the draft tokens, in which case we
        # will only have 1 token to pass in.
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -2:], num_tokens=n_spec_tokens
        )
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # breakpoint()
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )  # (batch, n_spec_tokens + 1, vocab_size)
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
        # num_generated_tokens[0].item() - 1 tokens from the draft model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset += num_generated
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
            scores_ref = model(
                cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
            ).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        print(
            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
        )
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run
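To make the shapes in sample_speculative concrete, here is a small CPU-only sketch with random logits. It is not part of the deleted file; the import assumes a checkout that still contains the module, and the batch/sequence/vocab sizes are arbitrary. The main model scores seqlen + 1 positions, the draft model scores seqlen positions, and between 1 and seqlen + 1 tokens come back as accepted.

# Hypothetical sketch: run the speculative-sampling acceptance rule on random logits.
import torch

from vllm_flash_attn.utils.generation import sample_speculative

batch, seqlen, vocab_size = 1, 4, 32
logits = torch.randn(batch, seqlen + 1, vocab_size)    # main model, seqlen + 1 positions
logits_draft = torch.randn(batch, seqlen, vocab_size)  # draft model, seqlen positions
tokens_draft = torch.randint(0, vocab_size, (batch, seqlen))

tokens, num_generated = sample_speculative(logits, logits_draft, tokens_draft, top_k=5)
print(tokens.shape)   # torch.Size([1, 5]) -> seqlen + 1 candidate tokens
print(num_generated)  # tensor([k]) with 1 <= k <= 5 accepted tokens for the sequence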
vllm_flash_attn/utils/pretrained.py  (deleted, 100644 → 0)
import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
    else:
        # Try loading from HF hub instead of from local files
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        if resolved_archive_file is None:
            resolved_archive_file = cached_file(
                model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
            )
            if resolved_archive_file is not None:
                is_sharded = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)

    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
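A typical call of state_dict_from_pretrained looks like the following sketch. It is not part of the deleted file; it assumes network access or a populated Hugging Face cache, a CUDA device for the final move, and uses "gpt2" only as an arbitrary example checkpoint name.

# Hypothetical sketch: load a checkpoint's weights as a plain state dict.
# With a non-fp32 dtype the tensors are loaded on CPU, converted to fp16,
# and only then moved to the target device.
import torch

from vllm_flash_attn.utils.pretrained import state_dict_from_pretrained

state_dict = state_dict_from_pretrained("gpt2", device="cuda", dtype=torch.float16)
print(len(state_dict), next(iter(state_dict)))  # number of tensors and the first key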