Unverified commit c45f3c3a
Authored Apr 01, 2023 by Zhuohan Li; committed by GitHub on Apr 01, 2023

Optimize tensor parallel execution speed (#17)

Parent: 7a7929ab
Showing 3 changed files with 103 additions and 287 deletions

  benchmark/benchmark_latency.py                          +99    -0
  cacheflow/parallel_utils/tensor_parallel/__init__.py     +0    -3
  cacheflow/parallel_utils/tensor_parallel/layers.py       +4  -284
benchmark/benchmark_latency.py
new file mode 100644
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (
    Server, add_server_arguments, initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            address='local',
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )
    sampling_params_dict = {
        'n': 1,
        'temperature': 0.0,
        'top_p': 1.0,
        'use_beam_search': False,
        'stop_token_ids': set(),
        'max_num_steps': args.output_len,
    }
    sampling_params = SamplingParams.from_dict(sampling_params_dict)
    input_token_ids = [0] * args.input_len

    def profile_step(profile=False):
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        for _ in range(args.batch_size):
            frontend._add_query(input_token_ids, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        start_time = time.time()
        while True:
            server.step()
            if not server.has_unfinished_requests():
                break
        end_time = time.time()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warm up step")
    profile_step()

    # Benchmark.
    latencies = []
    for _ in tqdm(range(3), desc="Profile step"):
        latencies.append(profile_step())
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    args = parser.parse_args()
    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
    print(args)
    main(args)
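For context on the numbers this script prints: profile_step enqueues args.batch_size identical prompts, runs server.step() until no unfinished requests remain, and reports wall-clock time; one untimed warm-up call precedes three timed runs averaged with np.mean. The standalone helper below is only a sketch of that warm-up-then-average measurement pattern and is not part of the commit; the function name time_fn and the explicit torch.cuda.synchronize() calls are assumptions about how one might tighten the timing.

import time
import numpy as np
import torch

def time_fn(fn, warmup: int = 1, iters: int = 3) -> float:
    # Untimed warm-up, mirroring the script's "Warm up step".
    for _ in range(warmup):
        fn()
    latencies = []
    for _ in range(iters):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # don't bill pending GPU work to this run
        start = time.time()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for this run's GPU work to finish
        latencies.append(time.time() - start)
    return float(np.mean(latencies))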
cacheflow/parallel_utils/tensor_parallel/__init__.py
@@ -6,8 +6,6 @@ from .layers import (
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
-    linear_with_grad_accumulation_and_async_allreduce
 )
 from .mappings import (
@@ -39,7 +37,6 @@ __all__ = [
     "set_defaults_if_not_set_tensor_model_parallel_attributes",
     "copy_tensor_model_parallel_attributes",
     "param_is_not_tensor_parallel_duplicate",
-    "linear_with_grad_accumulation_and_async_allreduce",
     # mappings.py
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
cacheflow/parallel_utils/tensor_parallel/layers.py
@@ -3,10 +3,6 @@
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
-import math
-import os
-from typing import Optional
-import warnings
 import torch
 import torch.nn.functional as F
@@ -16,31 +12,20 @@ from torch.nn.parameter import Parameter
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_group,
-    get_global_memory_buffer,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
-    gather_from_sequence_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
-    reduce_scatter_to_sequence_parallel_region,
 )
 from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
-    split_tensor_along_last_dim,
     VocabUtility,
 )
-
-_grad_accum_fusion_available = True
-try:
-    import fused_weight_gradient_mlp_cuda
-except ImportError:
-    _grad_accum_fusion_available = False
 
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}
@@ -216,202 +201,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output
 
-
-class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
-    """See linear_with_grad_accumulation_and_async_allreduce"""
-
-    @staticmethod
-    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, sequence_parallel):
-        ctx.save_for_backward(input, weight)
-        ctx.use_bias = bias is not None
-        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
-        ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.sequence_parallel = sequence_parallel
-
-        if sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-            dim_size = list(input.size())
-            dim_size[0] = dim_size[0] * world_size
-
-            all_gather_buffer = \
-                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
-            torch.distributed._all_gather_base(
-                all_gather_buffer,
-                input,
-                group=get_tensor_model_parallel_group())
-            total_input = all_gather_buffer
-        else:
-            total_input = input
-
-        output = torch.matmul(total_input, weight.t())
-        if bias is not None:
-            output = output + bias
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        use_bias = ctx.use_bias
-
-        if ctx.sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-            dim_size = list(input.size())
-            dim_size[0] = dim_size[0] * world_size
-
-            all_gather_buffer = \
-                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
-            handle = torch.distributed._all_gather_base(
-                all_gather_buffer,
-                input,
-                group=get_tensor_model_parallel_group(), async_op=True)
-
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # gather is scheduled before the input gradient computation
-            total_input = all_gather_buffer
-        else:
-            total_input = input
-        grad_input = grad_output.matmul(weight)
-
-        if ctx.sequence_parallel:
-            handle.wait()
-
-        # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
-                                       grad_output.shape[2])
-        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
-                                       total_input.shape[2])
-
-        if ctx.async_grad_allreduce:
-            # Asynchronous all-reduce
-            handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # all-reduce is scheduled before the weight gradient computation
-
-        if ctx.sequence_parallel:
-            assert not ctx.async_grad_allreduce
-            dim_size = list(input.size())
-            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
-                                         device=torch.cuda.current_device(),
-                                         requires_grad=False)
-            # reduce_scatter
-            handle = torch.distributed._reduce_scatter_base(
-                sub_grad_input, grad_input,
-                group=get_tensor_model_parallel_group(), async_op=True)
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # reduce scatter is scheduled before the weight gradient computation
-
-        if ctx.gradient_accumulation_fusion:
-            if weight.main_grad.dtype == torch.float32:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
-                    total_input, grad_output, weight.main_grad)
-            elif weight.main_grad.dtype == torch.float16:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
-                    total_input, grad_output, weight.main_grad)
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            grad_weight = None
-        else:
-            grad_weight = grad_output.t().matmul(total_input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-        if ctx.sequence_parallel:
-            handle.wait()
-            return sub_grad_input, grad_weight, grad_bias, None, None, None
-
-        if ctx.async_grad_allreduce:
-            handle.wait()
-
-        return grad_input, grad_weight, grad_bias, None, None, None
-
-
-def linear_with_grad_accumulation_and_async_allreduce(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    gradient_accumulation_fusion: bool,
-    async_grad_allreduce: bool,
-    sequence_parallel_enabled: bool,
-) -> torch.Tensor:
-    """Linear layer execution with asynchronous communication and
-    gradient accumulation fusion in backprop.
-
-    This has the option to accumulate the result of backprop
-    calculation into an existing gradient buffer, preventing the need
-    to do an additional addition kernel after the gradient
-    calculation.
-
-    Additionally, the tensor parallel all reduce of the input
-    gradients can be done asynchronously with the calculation of
-    the weight gradients.
-
-    In the case of sequence parallelism, the reduce scatter of the
-    input gradients is done asynchronously with the calcluation of the
-    weight gradients.
-
-    Use of this module requires that the environment variable
-    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
-    operations, noted in the code, that should be scheduled before
-    compute kernels to overlap the communication with the computation,
-    which is necessary for a speedup but not for correctness so that
-    ordering isn't imposed by the scheduler. Setting
-    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
-    in the order they are called.
-
-    Arguments:
-
-    input (torch.Tensor required): input like torch.nn.functional.linear
-
-    weight (torch.Tensor required): weight like torch.nn.functional.linear
-
-    bias (torch.Tensor optional): bias like torch.nn.functional.linear
-
-    gradient_accumulation_fusion (bool required): Perform the gradient
-        accumulation fusion, requires the custom CUDA extension
-        fused_weight_gradient_mlp_cuda module. To use
-        gradient_accumulation_fusion you must install APEX with
-        --cpp_ext and --cuda_ext. For example: "pip install
-        --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
-        Note that the extension requires CUDA>=11. Otherwise, you
-        must turn off gradient accumulation fusion."
-
-    async_grad_allreduce (bool required): Do the allreduce of input
-        gradients asyncronously with the computation of weight
-        gradients. If sequence_parallel_enabled is True, this must be
-        False, as no all reduce is performed.
-
-    sequence_parallel_enabled (bool required): Indicates that sequence
-        parallelism is used and thus in the forward pass the input is
-        all gathered, and the backward pass the input gradients are
-        reduce scattered.
-    """
-    args = [
-        input,
-        weight,
-        bias,
-        gradient_accumulation_fusion,
-        async_grad_allreduce,
-        sequence_parallel_enabled,
-    ]
-
-    if not linear_with_grad_accumulation_and_async_allreduce.warned:
-        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
-            if sequence_parallel_enabled:
-                warnings.warn(
-                    "When using sequence parallelism it is recommended to set the "
-                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
-                    "maximum speedup")
-                linear_with_grad_accumulation_and_async_allreduce.warned = True
-
-            if async_grad_allreduce:
-                warnings.warn(
-                    "When using async grad allreduce it is recommended to set the "
-                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
-                    "maximum speedup")
-                linear_with_grad_accumulation_and_async_allreduce.warned = True
-
-    with torch.cuda.amp.autocast(enabled=False):
-        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
-
-linear_with_grad_accumulation_and_async_allreduce.warned = False
-
-
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
@@ -436,11 +225,8 @@ class ColumnParallelLinear(torch.nn.Module):
         skip_bias_add: This was added to enable performance optimations where bias
                        can be fused with other elementwise operations. we skip
                        adding bias but instead return it.
-        async_tensor_model_parallel_allreduce:
         params_dtype:
         use_cpu_initialization:
-        gradient_accumulation_fusion:
-        sequence_parallel_enabled:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -448,12 +234,9 @@ class ColumnParallelLinear(torch.nn.Module):
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False,
                  skip_bias_add=False,
-                 async_tensor_model_parallel_allreduce=True,
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
-                 gradient_accumulation_fusion=False,
-                 sequence_parallel_enabled: bool = False,
                  ):
         super(ColumnParallelLinear, self).__init__()
@@ -506,37 +289,6 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
-        self.async_tensor_model_parallel_allreduce = (
-                async_tensor_model_parallel_allreduce and
-                world_size > 1)
-
-        if sequence_parallel_enabled:
-            if world_size <= 1:
-                warnings.warn(
-                    f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
-                    f"Disabling sequence parallel."
-                )
-                sequence_parallel_enabled = False
-        self.sequence_parallel_enabled = sequence_parallel_enabled
-
-        if gradient_accumulation_fusion:
-            if not _grad_accum_fusion_available:
-                raise RuntimeError(
-                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
-                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
-                    "module is not found. To use gradient_accumulation_fusion you must "
-                    "install APEX with --cpp_ext and --cuda_ext. For example: "
-                    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
-                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
-                    "gradient accumulation fusion."
-                )
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion
-
-        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
-            raise RuntimeError(
-                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
-                "cannot be enabled at the same time."
-            )
-
     def forward(self, input_):
         """Forward of ColumnParallelLinear
@@ -550,23 +302,11 @@ class ColumnParallelLinear(torch.nn.Module):
         """
         bias = self.bias if not self.skip_bias_add else None
 
-        if self.async_tensor_model_parallel_allreduce or \
-                self.sequence_parallel_enabled:
-            input_parallel = input_
-        else:
-            input_parallel = copy_to_tensor_model_parallel_region(input_)
+        input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
-            input=input_parallel,
-            weight=self.weight,
-            bias=bias,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
-            sequence_parallel_enabled=self.sequence_parallel_enabled,
-        )
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.sequence_parallel_enabled
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
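After this change, ColumnParallelLinear's forward is the inference-only pattern: copy the input into the tensor-parallel region, run a plain F.linear against this rank's weight shard, and optionally all-gather the partial outputs. The class below is a self-contained sketch of that pattern with the mappings helpers replaced by raw torch.distributed calls; the class name, constructor arguments, and the assumption that output_size divides evenly by world_size are illustrative, not from the diff.

import torch
import torch.nn.functional as F
import torch.distributed as dist

class ColumnParallelLinearSketch(torch.nn.Module):
    # Each rank holds output_size // world_size rows of the full weight matrix.
    def __init__(self, input_size: int, output_size: int, world_size: int,
                 gather_output: bool = True):
        super().__init__()
        self.world_size = world_size
        self.gather_output = gather_output
        self.weight = torch.nn.Parameter(
            torch.empty(output_size // world_size, input_size))

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Inference never calls backward, so a plain GEMM replaces the
        # autograd Function that handled the async grad all-reduce.
        output_parallel = F.linear(input_, self.weight)
        if self.gather_output and self.world_size > 1:
            parts = [torch.empty_like(output_parallel) for _ in range(self.world_size)]
            dist.all_gather(parts, output_parallel)
            return torch.cat(parts, dim=-1)   # stitch shards along the output dim
        return output_parallel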
@@ -607,8 +347,6 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
-        gradient_accumulation_fusion:
-        sequence_parallel_enabled:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -619,8 +357,6 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
-                 gradient_accumulation_fusion=False,
-                 sequence_parallel_enabled: bool = False,
                  ):
         super(RowParallelLinear, self).__init__()
@@ -635,10 +371,6 @@ class RowParallelLinear(torch.nn.Module):
         world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion
-        self.sequence_parallel_enabled = sequence_parallel_enabled
-        if self.sequence_parallel_enabled and not self.input_is_parallel:
-            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -669,7 +401,6 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias = Parameter(torch.empty(
                     self.output_size, device=torch.cuda.current_device(),
                     dtype=params_dtype))
-                setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
 
                 # Always initialize bias to zero.
                 with torch.no_grad():
@@ -693,23 +424,12 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            assert not self.sequence_parallel_enabled
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
-            input=input_parallel,
-            weight=self.weight,
-            bias=None,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=False,
-            sequence_parallel_enabled=False,
-        )
+        output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        if self.sequence_parallel_enabled:
-            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
-        else:
-            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
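RowParallelLinear follows the complementary pattern: each rank multiplies its shard of the input against its slice of the weight, a single all-reduce sums the partial products, and the replicated bias is added after the reduce. A matching self-contained sketch, again with raw torch.distributed in place of the mappings helpers, illustrative names, and even sharding assumed:

import torch
import torch.nn.functional as F
import torch.distributed as dist

class RowParallelLinearSketch(torch.nn.Module):
    # Each rank holds input_size // world_size columns of the full weight matrix.
    def __init__(self, input_size: int, output_size: int, world_size: int):
        super().__init__()
        self.world_size = world_size
        self.weight = torch.nn.Parameter(
            torch.empty(output_size, input_size // world_size))
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(self, input_parallel: torch.Tensor) -> torch.Tensor:
        # Partial product over this rank's input shard.
        output_parallel = F.linear(input_parallel, self.weight)
        if self.world_size > 1:
            dist.all_reduce(output_parallel)   # sum partial products across ranks
        return output_parallel + self.bias     # bias is replicated, added once after the reduce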