change / sglang · Commits · cf248976

Add Tensor Parallel to torch_native_llama (#1876)

Unverified commit cf248976, authored Nov 15, 2024 by Ke Wen, committed by GitHub on Nov 15, 2024. Parent: e5c67150
Showing 5 changed files with 246 additions and 82 deletions (+246 −82).
python/sglang/bench_latency.py                      +18 −4
python/sglang/srt/model_executor/model_runner.py    +16 −0
python/sglang/srt/model_parallel.py                 +98 −0
python/sglang/srt/models/torch_native_llama.py      +90 −78
test/srt/test_torch_tp.py                           +24 −0
python/sglang/bench_latency.py  (+18 −4)

```diff
@@ -220,8 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs
 
 
-@torch.inference_mode()
-def extend(reqs, model_runner):
+def _extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -237,8 +236,15 @@ def extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch
 
 
-@torch.inference_mode()
-def decode(input_token_ids, batch, model_runner):
+def extend(reqs, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _extend(reqs, model_runner)
+
+
+def _decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -248,6 +254,14 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
+def decode(input_token_ids, batch, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _decode(input_token_ids, batch, model_runner)
+
+
 def correctness_test(
     server_args,
     port_args,
```
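For readers unfamiliar with the workaround above: `torch.inference_mode` accepts a boolean `mode` argument, which is what lets the new `extend`/`decode` wrappers toggle it at runtime instead of using the decorator form. A minimal, self-contained sketch of the same pattern (function and variable names here are illustrative, not from the commit):

```python
# Sketch of a conditional torch.inference_mode wrapper, assuming only that the
# caller knows whether tensor parallelism (and hence DTensor) is in play.
import torch


def run_forward(model: torch.nn.Module, x: torch.Tensor, tp_applied: bool) -> torch.Tensor:
    # Disable inference mode when torch TP is applied, mirroring the
    # workaround in bench_latency.py (DTensor does not yet support it).
    use_inf_mode = not tp_applied
    with torch.inference_mode(use_inf_mode):
        return model(x)


if __name__ == "__main__":
    m = torch.nn.Linear(4, 4)
    y = run_forward(m, torch.randn(2, 4), tp_applied=False)
    print(y.requires_grad)  # False: tensor was created under inference mode
```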
python/sglang/srt/model_executor/model_runner.py  (+16 −0)

```diff
@@ -148,6 +148,15 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -551,6 +560,13 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
```
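The new `apply_torch_tp` builds a one-dimensional device mesh of shape `(tp_size,)` and hands the loaded model to the `tensor_parallel` helper added in this commit. A standalone sketch of that mesh setup, assuming a `torchrun` launch with one GPU per rank (the function name and environment handling are illustrative):

```python
# Illustrative sketch: build the same 1-D device mesh that
# ModelRunner.apply_torch_tp creates with (self.tp_size,).
import os

import torch


def build_tp_mesh():
    # torchrun sets WORLD_SIZE; init_device_mesh initializes the default
    # process group if needed and returns a mesh spanning all TP ranks.
    tp_size = int(os.environ.get("WORLD_SIZE", "1"))
    return torch.distributed.init_device_mesh("cuda", (tp_size,))


if __name__ == "__main__":
    mesh = build_tp_mesh()
    print(f"rank {torch.distributed.get_rank()}: mesh {mesh}")
```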
python/sglang/srt/model_parallel.py  (new file, mode 100644, +98 −0)

```python
"""
Common utilities for torch model parallelism.
"""

from typing import Optional, Sequence

import torch
from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.distributed.tensor import DTensor, Shard
except ImportError:
    # torch 2.4 or older
    from torch.distributed._tensor import DTensor, Shard

from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ColwiseParallelSharded(ColwiseParallel):
    """
    A version of ColwiseParallel where the local weight has been already
    sharded. This is used for the fused wqkv case, where during loading, we
    already sharded wq, wk, wv before fusing them.
    """

    # Override the _partition_linear_fn in ColwiseParallel
    def _partition_linear_fn(self, name, module, device_mesh):
        # colwise shard weight/bias to Shard(0), weight be Shard(0)
        # means Colwise as Linear is input * weight^T + bias, where
        # weight would become Shard(1)
        for name, param in module.named_parameters():
            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
            module.register_parameter(name, dist_param)


class RowwiseParallelMaybeWait(RowwiseParallel):
    """
    A version of RowwiseParallel that waits for the output (establish dependency
    between comm stream and compute stream in CUDA sense) before going into the
    next op. This is needed to workaround the current interaction between
    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
    """

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        outputs = super(RowwiseParallelMaybeWait, RowwiseParallelMaybeWait)._prepare_output_fn(
            output_layouts, use_local_output, mod, outputs, device_mesh
        )
        # wait for the output to be ready
        if isinstance(outputs, AsyncCollectiveTensor):
            return outputs.wait()
        else:
            return outputs


def tensor_parallel(
    module: torch.nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
):
    """
    Tensor parallelize the model across the given device mesh.

    Args:
        module (`torch.nn.Module`):
            The module to tensor parallelize.
        device_mesh (`torch.distributed.DeviceMesh`):
            The device mesh to use for tensor parallelism.
    """

    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
    # No op if `_tp_plan` attribute does not exist under the module.
    # This is a helper function to be used with `model.apply` to recursively
    # parallelize a model.
    def tplize(mod: torch.nn.Module) -> None:
        tp_plan = getattr(mod, "_tp_plan", None)
        if tp_plan is None:
            return
        for child_name, tp_style in tp_plan.items():
            submod = mod.get_submodule(child_name)
            if tp_style == "Colwise":
                parallelize_module(submod, device_mesh, ColwiseParallel())
            elif tp_style == "Rowwise":
                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
            elif tp_style == "Colwise_Sharded":
                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
            else:
                raise ValueError(f"Unknown TP style {tp_style}")

    # `apply` is a native method of `nn.Module` that recursively applies a
    # function to every submodule.
    module.apply(tplize)
```
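To make the contract of `tensor_parallel` concrete, here is a toy module (illustrative, not from the commit) that opts in by declaring a `_tp_plan`. Launched under `torchrun`, each rank keeps only its shard of the two projections, and the `Rowwise` output is reduced across ranks before being returned:

```python
# Toy example (illustrative names): a module opts into torch TP by declaring
# `_tp_plan`; sglang.srt.model_parallel.tensor_parallel then shards it.
import os

import torch
import torch.nn as nn

from sglang.srt.model_parallel import tensor_parallel


class ToyBlock(nn.Module):
    # up_proj is column-parallel, down_proj is row-parallel (standard pairing)
    _tp_plan = {"up_proj": "Colwise", "down_proj": "Rowwise"}

    def __init__(self, dim: int = 256):
        super().__init__()
        self.up_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.relu(self.up_proj(x)))


if __name__ == "__main__":
    # Launch with: torchrun --nproc-per-node=2 toy_tp.py   (assumes 2 GPUs)
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    tp_size = int(os.environ["WORLD_SIZE"])
    mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
    model = ToyBlock().cuda()
    tensor_parallel(model, mesh)
    out = model(torch.randn(8, 256, device="cuda"))
    print(out.shape)  # torch.Size([8, 256]) on every rank
```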
python/sglang/srt/models/torch_native_llama.py  (+90 −78)

````diff
@@ -17,6 +17,31 @@ limitations under the License.
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+# PyTorch Tensor Parallel Available for This Model
+"""
+This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
+Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+
+Here is a quick example to enable TP:
+```python
+from sglang.srt.model_parallel import tensor_parallel
+
+device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
+tensor_parallel(model, device_mesh)
+```
+
+An end-to-end example can be found in `python/sglang/bench_latency.py`.
+You can run it with the following command:
+```bash
+$ python3 -m sglang.bench_latency --correct \
+    --model meta-llama/Meta-Llama-3-8B \
+    --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
+    --tensor-parallel-size 2 \
+    --disable-cuda-graph
+```
+We will enable CUDA Graph support soon.
+"""
+
 import types
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -24,7 +49,10 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+
 
 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[int] = None,
+    loaded_shard_id: int,
 ):
-    if loaded_shard_id is None:
-        shard_offsets: List[Tuple[int, int, int]] = []
-        for i, output_size in enumerate(self.output_sizes):
-            shard_offsets.append((i, current_shard_offset, output_size))
-            current_shard_offset += output_size
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        assert loaded_shard_id < len(self.output_sizes)
-        param_data = param.data
-        shard_size = loaded_weight.shape[0]
-        shard_offset = loaded_shard_id * shard_size
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-    return
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
````
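The rewritten `gate_up_proj_weight_loader` first resizes the fused gate/up parameter to its per-rank size, then copies in only this rank's slice of each shard. A small, self-contained sketch of that narrow-based arithmetic on plain tensors (toy sizes and illustrative names, no distributed setup):

```python
# Toy sketch of per-rank sharding for a fused gate/up weight, mirroring the
# arithmetic in gate_up_proj_weight_loader (no real TP group involved).
import torch

hidden_size, intermediate_size = 8, 32
tp_size, tp_rank = 2, 1                                 # pretend we are rank 1 of 2
output_sizes = [intermediate_size, intermediate_size]   # [gate, up]

# Full (unsharded) shards as they come from the checkpoint.
full_gate = torch.randn(intermediate_size, hidden_size)
full_up = torch.randn(intermediate_size, hidden_size)

# shard_id -> (shard_offset, shard_size) inside the *per-rank* fused parameter
offsets, cur = {}, 0
for i, output_size in enumerate(output_sizes):
    output_size //= tp_size             # everything shrinks by tp_size
    offsets[i] = (cur, output_size)
    cur += output_size

param = torch.empty(cur, hidden_size)   # per-rank fused gate_up parameter

for shard_id, loaded_weight in enumerate([full_gate, full_up]):
    shard_offset, shard_size = offsets[shard_id]
    # take this rank's rows from the full shard, then copy into place
    local = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
    param.narrow(0, shard_offset, shard_size).copy_(local)

assert param.shape == (2 * (intermediate_size // tp_size), hidden_size)
```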
```diff
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
         return x
 
 
-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[str] = None,
+    loaded_shard_id: str,
 ):
-    if loaded_shard_id is None:
-        shard_offsets = [
-            # (shard_id, shard_offset, shard_size)
-            ("q", 0, self.total_num_heads * self.head_size),
-            (
-                "k",
-                self.total_num_heads * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-            (
-                "v",
-                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-        ]
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
-        shard_size = self._get_shard_size_mapping(loaded_shard_id)
-        param_data = param.data
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-    return
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
```
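The new `qkv_proj_weight_loader` applies the same scheme to the fused QKV weight: per-rank head counts determine the q/k/v offsets inside the fused parameter. A worked example of that offset arithmetic (Llama-3-8B-like head counts, purely illustrative):

```python
# Worked example (illustrative): q/k/v offsets inside the per-rank fused QKV
# weight, matching the qkv_offsets computed in qkv_proj_weight_loader.
head_size = 128
total_num_heads, total_num_kv_heads = 32, 8     # Llama-3-8B-like
tp_size = 2

num_heads = total_num_heads // tp_size          # 16 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size    # 4 kv heads per rank

# shard_id -> (shard_offset, shard_size), in rows of the fused weight
qkv_offsets = {
    "q": (0, num_heads * head_size),                                          # (0, 2048)
    "k": (num_heads * head_size, num_kv_heads * head_size),                   # (2048, 512)
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),  # (2560, 512)
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]                        # 3072 rows per rank
assert total_size == (num_heads + 2 * num_kv_heads) * head_size
```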
```diff
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +224,12 @@ class LlamaAttention(nn.Module):
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
```
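The attention module installs the custom loader on an existing `nn.Linear` instance with `types.MethodType`, so the loader sees the fused layer as `self`. A minimal sketch of that binding pattern (toy names, not from the commit):

```python
# Minimal sketch: bind a free function as a method on one specific instance,
# the same trick used to install qkv_proj_weight_loader on self.qkv_proj.
import types

import torch


def describe(self) -> str:
    # `self` is the particular Linear instance this function was bound to
    return f"Linear({self.in_features} -> {self.out_features})"


layer = torch.nn.Linear(4, 8)
layer.describe = types.MethodType(describe, layer)
print(layer.describe())  # Linear(4 -> 8)
```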
```diff
@@ -385,6 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
```
test/srt/test_torch_tp.py  (new file, mode 100644, +24 −0)

```python
import unittest

from sglang.test.test_utils import is_in_ci, run_bench_latency


class TestTorchTP(unittest.TestCase):
    def test_torch_native_llama(self):
        output_throughput = run_bench_latency(
            "meta-llama/Meta-Llama-3-8B",
            [
                "--tp",
                "2",
                "--json-model-override-args",
                '{"architectures": ["TorchNativeLlamaForCausalLM"]}',
                "--disable-cuda-graph",
            ],
        )

        if is_in_ci():
            assert output_throughput > 0, f"{output_throughput=}"


if __name__ == "__main__":
    unittest.main()
```