Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca9844b3
Unverified
Commit
ca9844b3
authored
Nov 05, 2024
by
youkaichao
Committed by
GitHub
Nov 05, 2024
Browse files
[bugfix] fix weak ref in piecewise cudagraph and tractable test (#10048)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
235366fe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
168 additions
and
25 deletions
+168
-25
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+101
-10
vllm/compilation/backends.py
vllm/compilation/backends.py
+67
-15
No files found.
tests/compile/piecewise/test_toy_llama.py
View file @
ca9844b3
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.
This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
import
os
from
dataclasses
import
dataclass
...
...
@@ -49,6 +53,12 @@ class LlamaConfig:
mlp_size
:
int
=
256
vocab_size
:
int
=
128
num_layers
:
int
=
2
init_value
:
float
=
1.0
tractable_init
:
bool
=
False
random_seed
:
int
=
0
def
__post_init__
(
self
):
assert
self
.
mlp_size
>=
self
.
hidden_size
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -66,10 +76,23 @@ class LlamaMLP(nn.Module):
bias
=
False
,
)
self
.
gate_up_projection
.
weight
.
data
.
fill_
(
0.0
)
self
.
down_projection
.
weight
.
data
.
fill_
(
0.0
)
if
config
.
tractable_init
:
nn
.
init
.
eye_
(
self
.
gate_up_projection
.
weight
.
data
[:
config
.
mlp_size
])
nn
.
init
.
eye_
(
self
.
gate_up_projection
.
weight
.
data
[
config
.
mlp_size
:])
nn
.
init
.
eye_
(
self
.
down_projection
.
weight
.
data
)
else
:
nn
.
init
.
xavier_normal_
(
self
.
gate_up_projection
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
)
nn
.
init
.
xavier_normal_
(
self
.
down_projection
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
)
def
forward
(
self
,
x
):
# for tractable_init and positive input, this is
# essentially an elementwise-square
x
=
self
.
gate_up_projection
(
x
)
x
=
x
[:,
:
x
.
size
(
1
)
//
2
]
*
torch
.
nn
.
functional
.
relu
(
x
[:,
x
.
size
(
1
)
//
2
:])
...
...
@@ -84,21 +107,39 @@ class LlamaAttention(nn.Module):
self
.
qkv_projection
=
nn
.
Linear
(
in_features
=
config
.
hidden_size
,
out_features
=
config
.
hidden_size
*
3
,
bias
=
False
,
)
self
.
output_projection
=
nn
.
Linear
(
in_features
=
config
.
hidden_size
,
out_features
=
config
.
hidden_size
,
bias
=
False
,
)
self
.
qkv_projection
.
weight
.
data
.
fill_
(
0.0
)
self
.
output_projection
.
weight
.
data
.
fill_
(
0.0
)
if
config
.
tractable_init
:
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[:
config
.
hidden_size
])
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[
config
.
hidden_size
:
2
*
config
.
hidden_size
])
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[
2
*
config
.
hidden_size
:])
nn
.
init
.
eye_
(
self
.
output_projection
.
weight
.
data
)
else
:
nn
.
init
.
xavier_normal_
(
self
.
qkv_projection
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
)
nn
.
init
.
xavier_normal_
(
self
.
output_projection
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# for tractable_init, this is:
# output = (hidden_states * 3 + positions * 2)
qkv
=
self
.
qkv_projection
(
hidden_states
)
hidden_size
=
qkv
.
size
(
-
1
)
//
3
q
,
k
,
v
=
qkv
.
split
([
hidden_size
,
hidden_size
,
hidden_size
],
dim
=-
1
)
...
...
@@ -126,20 +167,29 @@ class LlamaDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
For tractable computation:
- if residual is None, the outputs are:
- residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
- if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
"""
# noqa
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
hidden_states
/
2
hidden_states
=
hidden_states
+
1
else
:
hidden_states
=
hidden_states
+
residual
residual
=
hidden_states
hidden_states
=
hidden_states
/
2
hidden_states
=
hidden_states
+
1
hidden_states
=
self
.
self_attention
(
positions
=
positions
,
hidden_states
=
hidden_states
)
hidden_states
=
hidden_states
+
residual
residual
=
hidden_states
hidden_states
=
hidden_states
/
2
hidden_states
=
hidden_states
+
1
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
...
...
@@ -156,7 +206,8 @@ class LlamaModel(nn.Module):
self
.
layers
=
nn
.
ModuleList
(
[
LlamaDecoderLayer
(
config
)
for
_
in
range
(
config
.
num_layers
)])
self
.
embedding_tokens
.
weight
.
data
.
fill_
(
0.0
)
# this is the initial value of the hidden states
self
.
embedding_tokens
.
weight
.
data
.
fill_
(
config
.
init_value
)
def
forward
(
self
,
...
...
@@ -170,6 +221,28 @@ class LlamaModel(nn.Module):
return
hidden_states
def
tractable_computation
(
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
config
:
LlamaConfig
,
init_value
:
float
=
1.0
)
->
torch
.
Tensor
:
hidden_states
=
torch
.
ones
(
input_ids
.
size
(
0
),
config
.
hidden_size
,
device
=
input_ids
.
device
,
dtype
=
input_ids
.
dtype
)
*
init_value
# first layer
residual
=
hidden_states
*
4
+
positions
.
unsqueeze
(
1
)
*
2
+
3
hidden_states
=
(
residual
+
1
)
**
2
# following layers
for
_
in
range
(
config
.
num_layers
-
1
):
hidden_states
=
hidden_states
+
residual
residual
=
hidden_states
*
4
+
positions
.
unsqueeze
(
1
)
*
2
+
3
hidden_states
=
(
residual
+
1
)
**
2
return
hidden_states
@
torch
.
inference_mode
def
run_model
(
llama_config
,
use_compile
:
bool
,
...
...
@@ -213,7 +286,15 @@ def run_model(llama_config,
del
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
set_compilation_config
(
None
)
return
output
.
cpu
()
output
=
output
.
cpu
()
if
llama_config
.
tractable_init
:
expected_output
=
tractable_computation
(
input_ids
[:
2
],
positions
[:
2
],
llama_config
).
cpu
()
assert
torch
.
allclose
(
output
,
expected_output
)
else
:
return
output
.
cpu
()
def
test_toy_llama
():
...
...
@@ -222,7 +303,13 @@ def test_toy_llama():
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
mlp_size
=
256
,
vocab_size
=
128
,
num_layers
=
2
)
num_layers
=
12
)
tractable_config
=
LlamaConfig
(
hidden_size
=
128
,
mlp_size
=
256
,
vocab_size
=
128
,
num_layers
=
2
,
tractable_init
=
True
)
outputs
=
[]
with
compilation_counter
.
expect
(
...
...
@@ -233,6 +320,8 @@ def test_toy_llama():
num_cudagraph_caputured
=
0
,
):
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
False
))
run_model
(
tractable_config
,
use_compile
=
False
)
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_piecewise_graphs_seen
=
1
,
...
...
@@ -242,6 +331,7 @@ def test_toy_llama():
2
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
True
))
run_model
(
tractable_config
,
use_compile
=
True
)
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
...
...
@@ -257,6 +347,7 @@ def test_toy_llama():
):
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
True
,
split_attn
=
True
))
run_model
(
tractable_config
,
use_compile
=
True
,
split_attn
=
True
)
for
i
in
range
(
1
,
len
(
outputs
)):
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
...
...
vllm/compilation/backends.py
View file @
ca9844b3
...
...
@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import
torch
import
torch.fx
as
fx
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
weak_ref_tensors
...
...
@@ -193,6 +194,7 @@ def wrap_inductor(graph,
@
dataclasses
.
dataclass
class
SplitItem
:
submod_name
:
str
graph_id
:
int
is_splitting_graph
:
bool
graph
:
fx
.
GraphModule
...
...
@@ -226,9 +228,7 @@ def split_graph(graph: fx.GraphModule,
outputs
=
[]
# sort the names to make sure the order is deterministic
names
=
[
name
for
(
name
,
module
)
in
split_gm
.
named_modules
()]
names
.
sort
()
for
name
in
names
:
if
"."
in
name
or
name
==
""
:
...
...
@@ -238,7 +238,11 @@ def split_graph(graph: fx.GraphModule,
module
=
getattr
(
split_gm
,
name
)
graph_id
=
int
(
name
.
replace
(
"submod_"
,
""
))
outputs
.
append
(
SplitItem
(
name
,
graph_id
in
split_op_graphs
,
module
))
outputs
.
append
(
SplitItem
(
name
,
graph_id
,
(
graph_id
in
split_op_graphs
),
module
))
# sort by intetger graph_id, rather than string name
outputs
.
sort
(
key
=
lambda
x
:
x
.
graph_id
)
return
split_gm
,
outputs
...
...
@@ -252,6 +256,11 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
compilation configs.
NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
graphs. The first graph will handle logging, and the last graph
has some special cudagraph output handling.
"""
def
__init__
(
self
,
module
:
torch
.
fx
.
GraphModule
,
...
...
@@ -263,7 +272,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
compile_submod_names
=
compile_submod_names
self
.
compilation_configs
=
compilation_configs
self
.
graph_pool
=
graph_pool
self
.
have_seen_first_graph
=
False
def
run
(
self
,
*
args
):
fake_args
=
[
...
...
@@ -279,6 +287,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
output
=
super
().
call_module
(
target
,
args
,
kwargs
)
if
target
in
self
.
compile_submod_names
:
index
=
self
.
compile_submod_names
.
index
(
target
)
submod
=
self
.
fetch_attr
(
target
)
sym_shape_indices
=
[
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
...
...
@@ -288,15 +297,14 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
args
,
self
.
compilation_configs
.
inductor_compile_config
,
runtime_shape
=
None
,
do_logging
=
not
self
.
have_seen_first_graph
,
do_logging
=
index
==
0
,
use_inductor
=
self
.
compilation_configs
.
use_inductor
)
self
.
module
.
__dict__
[
target
]
=
PiecewiseBackend
(
submod
,
self
.
compilation_configs
,
self
.
graph_pool
,
not
self
.
have_seen_first_graph
,
sym_shape_indices
,
submod
,
self
.
compilation_configs
,
self
.
graph_pool
,
index
,
len
(
self
.
compile_submod_names
)
,
sym_shape_indices
,
compiled_graph_for_general_shape
)
self
.
have_seen_first_graph
=
True
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
1
return
output
...
...
@@ -352,8 +360,9 @@ class VllmBackend:
graph
,
self
.
compilation_configs
.
non_cudagraph_ops
)
from
torch._dynamo.utils
import
lazy_format_graph_code
logger
.
debug
(
"%s"
,
lazy_format_graph_code
(
"stiching module"
,
self
.
split_gm
))
logger
.
debug
(
"%s"
,
lazy_format_graph_code
(
"before split"
,
self
.
graph
))
logger
.
debug
(
"%s"
,
lazy_format_graph_code
(
"after split"
,
self
.
split_gm
))
compilation_counter
.
num_piecewise_graphs_seen
+=
len
(
self
.
piecewise_graphs
)
...
...
@@ -385,12 +394,17 @@ class ConcreteSizeEntry:
cudagraph
:
Optional
[
torch
.
cuda
.
CUDAGraph
]
=
None
output
:
Optional
[
Any
]
=
None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses
:
Optional
[
List
[
int
]]
=
None
class
PiecewiseBackend
:
def
__init__
(
self
,
graph
:
fx
.
GraphModule
,
compilation_configs
:
CompilationConfig
,
graph_pool
:
Any
,
is_first_graph
:
bool
,
sym_shape_indices
:
List
[
int
],
piecewise_compile_index
:
int
,
total_piecewise_compiles
:
int
,
sym_shape_indices
:
List
[
int
],
compiled_graph_for_general_shape
:
Callable
):
"""
The backend for piecewise compilation.
...
...
@@ -408,7 +422,12 @@ class PiecewiseBackend:
self
.
graph
=
graph
self
.
compilation_configs
=
compilation_configs
self
.
graph_pool
=
graph_pool
self
.
is_first_graph
=
is_first_graph
self
.
piecewise_compile_index
=
piecewise_compile_index
self
.
total_piecewise_compiles
=
total_piecewise_compiles
self
.
is_first_graph
=
piecewise_compile_index
==
0
self
.
is_last_graph
=
(
piecewise_compile_index
==
total_piecewise_compiles
-
1
)
self
.
compile_sizes
:
Set
[
int
]
=
set
(
self
.
compilation_configs
.
compile_sizes
)
...
...
@@ -422,6 +441,8 @@ class PiecewiseBackend:
self
.
sym_shape_indices
=
sym_shape_indices
self
.
is_debugging_mode
=
envs
.
VLLM_LOGGING_LEVEL
==
"DEBUG"
# the entries for different shapes that we need to either
# compile or capture cudagraph
self
.
concrete_size_entries
:
Dict
[
int
,
ConcreteSizeEntry
]
=
{}
...
...
@@ -476,14 +497,45 @@ class PiecewiseBackend:
logger
.
info
(
"Capturing a cudagraph for shape %s"
,
runtime_shape
)
input_addresses
=
[
x
.
data_ptr
()
for
x
in
args
if
isinstance
(
x
,
torch
.
Tensor
)
]
entry
.
input_addresses
=
input_addresses
cudagraph
=
torch
.
cuda
.
CUDAGraph
()
# mind-exploding: carefully manage the reference and memory.
with
torch
.
cuda
.
graph
(
cudagraph
,
pool
=
self
.
graph_pool
):
entry
.
output
=
weak_ref_tensors
(
entry
.
runnable
(
*
args
))
# `output` is managed by pytorch's cudagraph pool
output
=
entry
.
runnable
(
*
args
)
if
self
.
is_last_graph
:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output
=
weak_ref_tensors
(
output
)
# here we always use weak ref for the output
# to save memory
entry
.
output
=
weak_ref_tensors
(
output
)
entry
.
cudagraph
=
cudagraph
compilation_counter
.
num_cudagraph_caputured
+=
1
entry
.
cudagraph
=
cudagraph
return
entry
.
output
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return
output
if
self
.
is_debugging_mode
:
# check if the input addresses are the same
new_input_addresses
=
[
x
.
data_ptr
()
for
x
in
args
if
isinstance
(
x
,
torch
.
Tensor
)
]
assert
new_input_addresses
==
entry
.
input_addresses
,
(
"Input addresses for cudagraphs are different during replay."
f
" Expected
{
entry
.
input_addresses
}
, got
{
new_input_addresses
}
"
)
entry
.
cudagraph
.
replay
()
return
entry
.
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment