OpenDAS / ColossalAI · Commit 4268ae01 (unverified)
Authored Nov 08, 2022 by アマデウス; committed via GitHub on Nov 08, 2022
Parent commit: 76e64cb6

[kernel] added jit warmup (#1792)
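The commit adds a warmup routine for ColossalAI's JIT-fused kernels. The motivation, stated here as background rather than taken from the commit itself: functions compiled with torch.jit.script are specialized and fused on their first few invocations, so without a warmup pass the initial training iterations absorb the compilation cost, separately for each requires_grad combination the fuser encounters. A minimal self-contained sketch of the effect, using a Megatron-style bias-GeLU as the example:

    import torch

    @torch.jit.script
    def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # tanh approximation of GeLU applied to (y + bias)
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    b = torch.randn(1024, device=device)
    y = torch.randn(8, 1024, device=device)
    for _ in range(10):    # first calls trigger compilation; later calls hit the cache
        bias_gelu(b, y)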
Showing 4 changed files with 81 additions and 21 deletions (+81 −21):

    colossalai/kernel/jit/option.py              +47   −0
    colossalai/nn/__init__.py                     +1   −1
    colossalai/nn/layer/parallel_1d/layers.py    +33  −19
    colossalai/nn/layer/parallel_3d/_utils.py     +0   −1
colossalai/kernel/jit/option.py

 import torch
+
+from colossalai.nn.layer.colossalai_layer import Embedding, Linear
+from colossalai.utils import get_current_device
+
+from .bias_dropout_add import bias_dropout_add_fused_train
+from .bias_gelu import bias_gelu_impl
 
 JIT_OPTIONS_SET = False

@@ -30,3 +36,44 @@ def set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_gpu(True)
 
     JIT_OPTIONS_SET = True
+
+
+def warmup_jit_fusion(batch_size: int,
+                      hidden_size: int,
+                      seq_length: int = 512,
+                      vocab_size: int = 32768,
+                      dtype: torch.dtype = torch.float32):
+    """ Compile JIT functions before the main training steps """
+
+    embed = Embedding(vocab_size, hidden_size).to(get_current_device())
+    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
+    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
+
+    x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+    x = embed(x)
+    y, y_bias = linear_1(x)
+    z, z_bias = linear_2(y)
+
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for bias_grad, input_grad in zip([True, True], [False, True]):
+        for _ in range(10):
+            bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
+            input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+            bias.requires_grad, input_.requires_grad = bias_grad, input_grad
+            bias_gelu_impl(input_, bias)
+
+    # Warmup fused bias+dropout+add
+    dropout_rate = 0.1
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+        for _ in range(10):
+            input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
+            residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
+            bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+            input_.requires_grad = input_grad
+            bias.requires_grad = bias_grad
+            residual.requires_grad = residual_grad
+            bias_dropout_add_fused_train(input_, bias, residual, dropout_rate)
+
+    torch.cuda.empty_cache()
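For orientation (not stated in the diff itself): warmup_jit_fusion appears to mirror Megatron-LM's JIT warmup and is meant to run once, after initialization and before the first training step. A usage sketch, with illustrative argument values and the import path assumed from this file's location:

    import torch

    from colossalai.kernel.jit.option import set_jit_fusion_options, warmup_jit_fusion

    set_jit_fusion_options()            # configure the PyTorch fuser flags
    warmup_jit_fusion(batch_size=32,    # hypothetical values; match your model config
                      hidden_size=1024,
                      seq_length=512,
                      vocab_size=32768,
                      dtype=torch.float16)
    # ...proceed with the regular training loop...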
colossalai/nn/__init__.py

+from ._ops import *
 from .layer import *
 from .loss import *
 from .lr_scheduler import *
 from .metric import *
 from .optimizer import *
-from ._ops import *
colossalai/nn/layer/parallel_1d/layers.py

@@ -7,6 +7,9 @@ from typing import Callable, Tuple
 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn.parameter import Parameter
+
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc

@@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn.parameter import Parameter
 
-from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
-                     split_forward_gather_backward)
+from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
+from ._operation import linear_with_async_comm
+from ._utils import (
+    gather_forward_split_backward,
+    get_parallel_input,
+    reduce_grad,
+    reduce_input,
+    set_parallel_input,
+    split_forward_gather_backward,
+)
+
+Fast_LN = None
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+    Fast_LN = FastLayerNorm
+except ImportError:
+    pass
 
 @LAYERS.register_module

@@ -102,19 +120,15 @@ class LayerNorm1D(ColossalaiModule):
     ]
 
     def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
-        from apex.normalization import FusedLayerNorm
-
-        fast_ln_installed = False
-        try:
-            from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-            fast_ln_installed = True
-        except ImportError:
-            pass
-
-        if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
-            norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
+        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
+            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
         else:
-            norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            norm = None
+            try:
+                from apex.normalization import FusedLayerNorm
+                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            except ImportError:
+                norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
         super().__init__(norm)
 
     def _load_from_state_dict(self, state_dict, prefix, *args):
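Two things are worth noting about this refactor: the per-instance try/except import of FastLayerNorm moves to a single module-level probe cached in Fast_LN, and LayerNorm1D now degrades gracefully when apex is absent altogether, falling back to colossalai.kernel.LayerNorm instead of raising ImportError from apex.normalization. The fallback chain in isolation, as a generic sketch (not ColossalAI code), with hypothetical names fast_impl and fast_sizes standing in for the probed class and its supported shapes:

    import torch

    def pick_layer_norm(normalized_shape: int, eps: float, dtype, fast_impl, fast_sizes):
        # Prefer the optional fast kernel when present and the shape is supported...
        if fast_impl is not None and normalized_shape in fast_sizes:
            return fast_impl(normalized_shape, eps=eps).to(dtype)
        try:
            # ...then apex's fused layer norm...
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
        except ImportError:
            # ...and finally an implementation that is always available.
            return torch.nn.LayerNorm(normalized_shape, eps=eps).to(dtype)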
colossalai/nn/layer/parallel_3d/_utils.py

@@ -5,7 +5,6 @@ import torch
 from torch import Tensor
-
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env