OpenDAS / ColossalAI · Commit 08f2920e

Commit 08f2920e, authored Apr 23, 2023 by zhuwenwen
Parent: da3f0934 · Pipeline #237 failed with stages in 0 seconds
Changes: 380 · Pipelines: 1

    init colossalai, support dtk2304

Showing 20 changed files with 3889 additions and 0 deletions (+3889, -0)
colossalai/gemini/paramhooks/__init__.py                        +3     -0
colossalai/gemini/paramhooks/_param_hookmgr.py                  +38    -0
colossalai/gemini/placement_policy.py                           +245   -0
colossalai/gemini/stateful_tensor.py                            +209   -0
colossalai/gemini/stateful_tensor_mgr.py                        +100   -0
colossalai/gemini/tensor_placement_policy.py                    +138   -0
colossalai/gemini/tensor_utils.py                               +118   -0
colossalai/global_variables.py                                  +56    -0
colossalai/initialize.py                                        +472   -0
colossalai/kernel/__init__.py                                   +3     -0
colossalai/kernel/cuda_native/__init__.py                       +3     -0
colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp      +49    -0
colossalai/kernel/cuda_native/csrc/compat.h                     +10    -0
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp                 +459   -0
colossalai/kernel/cuda_native/csrc/cpu_adam.h                   +164   -0
colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu     +191   -0
colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu   +171   -0
colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu         +176   -0
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu   +1041  -0
colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu   +243   -0
colossalai/gemini/paramhooks/__init__.py (new file, mode 100644)

from ._param_hookmgr import BaseParamHookMgr

__all__ = ["BaseParamHookMgr"]
colossalai/gemini/paramhooks/_param_hookmgr.py (new file, mode 100644)

from typing import Callable, List
import torch
import functools


class BaseParamHookMgr(object):

    def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
        r"""
        Registers a backward hook on every parameter of the module.
        """
        self._param_list = param_list
        self._hook_list = []

    def register_backward_hooks(self, hook_call: Callable) -> None:
        r"""
        The hook_call will be called every time a gradient with respect to a param in
        self.param_list is computed.
        The hook should have the following signature:
        ```
        hook(param, grad) -> Tensor or None
        ```
        """
        if not torch.is_grad_enabled():
            return    # don't register grad hooks if grad isn't enabled
        for p in self._param_list:
            if p.requires_grad and not hasattr(p, '_base_param_hook'):
                handle = p.register_hook(functools.partial(hook_call, p))
                p._base_param_hook = handle

    def remove_hooks(self) -> None:
        """
        Remove hooks from model parameters.
        """
        for p in self._param_list:
            if p.requires_grad and hasattr(p, '_base_param_hook'):
                p._base_param_hook.remove()
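
To illustrate the manager added above, here is a minimal usage sketch (not part of the diff; the linear model and the printing hook are hypothetical). Note the hook receives `(param, grad)` because `register_backward_hooks` binds each parameter with `functools.partial`:

import torch

# hypothetical example model and hook; only BaseParamHookMgr comes from this file
model = torch.nn.Linear(4, 2)

def print_grad_hook(param, grad):
    # called as (param, grad): the parameter is bound as the first argument
    print(f"grad computed for param of shape {tuple(param.shape)}")
    return grad    # return the (possibly modified) grad, or None to keep it unchanged

mgr = BaseParamHookMgr(list(model.parameters()))
mgr.register_backward_hooks(print_grad_hook)

loss = model(torch.randn(3, 4)).sum()
loss.backward()    # hooks fire here
mgr.remove_hooks()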
colossalai/gemini/placement_policy.py (new file, mode 100644)

import functools
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type

import torch

from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity


class PlacementPolicy(ABC):
    need_mem_stats: bool = False

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        self.chunk_manager = chunk_manager
        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector

    @abstractmethod
    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        raise NotImplementedError

    @staticmethod
    def get_default_device() -> torch.device:
        return torch.device('cpu')


class CPUPlacementPolicy(PlacementPolicy):

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        volume = 0
        start = time()
        for chunk in can_evict_chunks:
            self.chunk_manager.release_chunk(chunk)
            self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
            volume += chunk.chunk_mem
        return volume, time() - start


class CUDAPlacementPolicy(PlacementPolicy):

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        return 0, 0

    @staticmethod
    def get_default_device() -> torch.device:
        return get_current_device()


class AutoPlacementPolicy(PlacementPolicy):

    need_mem_stats: bool = True
    # model data will use 1 - _warmup_non_model_data_ratio CUDA memory in the warmup phase
    # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
    # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
    _warmup_non_model_data_ratio: float = 0.8
    _steady_cuda_cap_ratio: float = 0.9

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self,
                      can_evict_chunks: List[Chunk],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
                      compute_idx: int = 0,
                      **kwargs) -> Tuple[int, float]:
        """
        Evict tensors from the CUDA device.

        Args:
            can_evict_chunks (List[Chunk]): the list of chunks that can be evicted.
            cuda_demand (int, optional): the volume of data needed on the cuda device. Defaults to 0.
            warmup (bool, optional): a flag indicating whether we are in the warmup phase. Defaults to True.
            compute_list (List[Tuple[Chunk, ...]], optional): TODO. Defaults to [].
            compute_idx (int, optional): the index of the current compute step in ``compute_list``. Defaults to 0.

        Raises:
            RuntimeError: if not enough CUDA memory can be freed.

        Returns:
            Tuple[int, float]: the volume of memory that is evicted and the time spent evicting.
        """
        start = time()
        cuda_capacity = colo_device_memory_capacity(get_current_device())
        used_cuda_model_data = self.chunk_manager.total_mem['cuda']
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
            cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
        freed_cuda_model_data = 0

        if avail_cuda_model_data < cuda_demand:
            # Move cuda_demand - avail_cuda_model_data volume of tensors
            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_chunks = can_evict_chunks
            if not warmup:
                to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
                # print(self._sort_can_evict_chunks.cache_info())
            for chunk in to_free_chunks:
                if freed_cuda_model_data >= to_free_cuda_model_data:
                    break
                self.chunk_manager.release_chunk(chunk)
                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
                freed_cuda_model_data += chunk.chunk_mem
            if freed_cuda_model_data < to_free_cuda_model_data:
                raise RuntimeError(f"Adjust layout failed! Not enough CUDA memory! "
                                   f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
        return freed_cuda_model_data, time() - start

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
        next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
        for i in range(len(compute_list) - 1, compute_idx, -1):
            for chunk in compute_list[i]:
                if chunk in next_compute_idx:
                    next_compute_idx[chunk] = i
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]

    @staticmethod
    def set_warmup_non_model_data_ratio(ratio: float) -> None:
        ratio = float(ratio)
        assert 0.0 < ratio < 1.0
        AutoPlacementPolicy._warmup_non_model_data_ratio = ratio

    @staticmethod
    def set_steady_cuda_cap_ratio(ratio: float) -> None:
        ratio = float(ratio)
        assert 0.0 < ratio < 1.0
        AutoPlacementPolicy._steady_cuda_cap_ratio = ratio


class ConstPlacementPolicy(PlacementPolicy):

    need_mem_stats: bool = False
    _accessed_memory_boundary = 512 * 1024 ** 2

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self,
                      can_evict_chunks: List[Chunk],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
                      compute_idx: int = 0,
                      **kwargs) -> Tuple[int, float]:
        """
        See the docstrings in the class `AutoPlacementPolicy`.
        """
        start = time()
        used_accessed_memory = self.chunk_manager.accessed_mem
        avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
        freed_accessed_memory = 0

        if avail_accessed_memory < cuda_demand:
            to_free_memory = cuda_demand - avail_accessed_memory
            to_free_chunks = can_evict_chunks
            if not warmup:
                # sort all chunks
                to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
            for chunk in to_free_chunks:
                if freed_accessed_memory >= to_free_memory:
                    break
                self.chunk_manager.release_chunk(chunk)
                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
                freed_accessed_memory += chunk.chunk_mem
            if freed_accessed_memory < to_free_memory:
                raise RuntimeError(f"Adjust layout failed! Not enough CUDA memory! "
                                   f"Need {to_free_memory}, freed {freed_accessed_memory}")
        return freed_accessed_memory, time() - start

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
        next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
        for i in range(len(compute_list) - 1, compute_idx, -1):
            for chunk in compute_list[i]:
                if chunk in next_compute_idx:
                    next_compute_idx[chunk] = i
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]

    @staticmethod
    def set_const_memory_boundary(cuda_memory_mb: int) -> None:
        boundary = int(cuda_memory_mb * 1024 ** 2)
        assert boundary > 0
        ConstPlacementPolicy._accessed_memory_boundary = boundary


class PlacementPolicyFactory:
    policies: Dict[str, Type[PlacementPolicy]] = {
        'cpu': CPUPlacementPolicy,
        'cuda': CUDAPlacementPolicy,
        'auto': AutoPlacementPolicy,
        'const': ConstPlacementPolicy
    }

    @staticmethod
    def create(policy_name: str) -> Type[PlacementPolicy]:
        if policy_name not in PlacementPolicyFactory.policies:
            raise TypeError(f"Unknown tensor placement policy {policy_name}")
        return PlacementPolicyFactory.policies[policy_name]

    @staticmethod
    def get_polocy_names():
        return tuple(PlacementPolicyFactory.policies.keys())

    @staticmethod
    def get_default_device(policy_name: str) -> torch.device:
        policy_cls = PlacementPolicyFactory.create(policy_name)
        return policy_cls.get_default_device()
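
`_sort_can_evict_chunks` orders candidates so that the chunk whose next use lies furthest in the future is evicted first (a Belady-style heuristic). A self-contained toy (not part of the diff) that runs the same logic on strings instead of `Chunk` objects:

from functools import lru_cache

@lru_cache(maxsize=None)
def sort_can_evict(can_evict: tuple, compute_idx: int, compute_list: tuple) -> list:
    # identical logic to _sort_can_evict_chunks, on hashable toy items:
    # items whose next use is furthest in the future come first
    next_compute_idx = {c: len(compute_list) for c in can_evict}
    for i in range(len(compute_list) - 1, compute_idx, -1):
        for c in compute_list[i]:
            if c in next_compute_idx:
                next_compute_idx[c] = i
    ordered = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
    return [c for (c, idx) in ordered]

# 'a' is next used at step 3, 'b' at step 1, 'c' never again (next use = 4)
compute_list = (('a',), ('b',), ('b',), ('a',))
print(sort_can_evict(('a', 'b', 'c'), 0, compute_list))    # ['c', 'a', 'b']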
colossalai/gemini/stateful_tensor.py (new file, mode 100644)

from enum import Enum
from typing import Optional, Union

import torch

from colossalai.gemini.gemini_context import GeminiMemoryManager


def sizeof_tensor(tensor: torch.Tensor):
    return tensor.numel() * tensor.element_size()


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class StatefulTensor(object):
    """A structure that stores a torch Tensor together with a labeled state.

    Inspired by the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """
    # Global Stateful Tensor Manager
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of the current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # when the state is FREE, the payload should be None
            assert maybe_tensor is None, f"payload has to be None if state is {self._state}"
        else:
            # otherwise, the payload should not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        if self._payload is None:
            return 0    # if a tensor has no storage, 0 should be returned
        return self._payload.data_ptr()

    def set_null(self) -> None:
        # notice that a free stateful tensor does not need to become null again
        if self.state != TensorState.FREE:
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        if self.state == TensorState.FREE:
            # check sanity here
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        if self.state == TensorState.FREE:
            # a free stateful tensor can't change state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int]):
        assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

        if not isinstance(device, torch.device):
            to_device = torch.device('cuda', device)
        else:
            to_device = device

        from_device_type = self.device.type
        if from_device_type == to_device.type:
            # from device == to device
            return

        # update the manager's information
        self.__trans_device_update(from_device_type, to_device.type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor) -> None:
        self._payload.view(-1).copy_(tensor.view(-1))

    def payload_reset(self, tensor) -> None:
        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

        if self.payload is not None:
            # release the old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for the new payload
            self._state = TensorState.HOLD
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record the new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        # relay the payload of rhs to the current stateful tensor
        # can't support null relay right now
        assert not rhs.is_null()

        # now this function only supports a stateful tensor that has a zero-length payload,
        # because it doesn't require updating the memory manager
        # you can extend this function by yourself
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

    @property
    def payload_size(self) -> int:
        return self._payload_size

    @property
    def state(self) -> TensorState:
        return self._state

    @property
    def device(self) -> torch.device:
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of calling .to() on StatefulTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of calling .to_() on StatefulTensor")

    def __release(self):
        # release the current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update the global manager when changing the state of a tensor"""
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to the manager
            # we should add its memory
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon
            # we should subtract its memory
            manager.total_mem[device_type] -= size

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update the global manager when changing the device of a tensor"""
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size

    def __del__(self):
        self.set_null()
        StatefulTensor.GST_MGR.delete_instance()
        del self
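
A short lifecycle sketch for `StatefulTensor` (not part of the diff; it assumes this module and its `GeminiMemoryManager` dependency import cleanly):

import torch
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState

# a CPU payload of 1024 fp32 elements; GST_MGR starts tracking 4096 bytes under HOLD
st = StatefulTensor(torch.zeros(1024), state=TensorState.HOLD)
print(st.payload_size)                 # 4096

st.trans_state(TensorState.COMPUTE)    # 4096 bytes move from HOLD to COMPUTE in state_mem
st.trans_state(TensorState.HOLD)
st.set_null()                          # payload released; state becomes FREE
print(st.is_null())                    # True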
colossalai/gemini/stateful_tensor_mgr.py (new file, mode 100644)

import functools
import torch
import types
from colossalai.utils.cuda import get_current_device
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger
from time import time


class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
        self._stateful_tensor_list: List[StatefulTensor] = []
        self._compute_list: List[StatefulTensor] = []
        self._compute_idx: int = -1

        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True

    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
        assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
        self._stateful_tensor_list = tensor_list
        for t in self._stateful_tensor_list:
            assert isinstance(t, StatefulTensor)
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def start_iter(self):
        pass

    def finish_iter(self):
        """This function must be called when each iteration finishes"""
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

    def adjust_layout(self) -> None:
        """Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belong to a Sharded Model.
        """
        # find stateful tensors in state COMPUTE
        cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
        start = time()
        move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
        self._layout_time += time() - start

        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                                      cuda_demand=cuda_demand,
                                                                      warmup=self._warmup,
                                                                      compute_list=self._compute_list,
                                                                      compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += vol
        self._evict_time += evict_time

        # move COMPUTE tensors to CUDA
        self._cpu_gpu_move_volume += cuda_demand
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    @property
    def cpu_gpu_move_volume(self):
        return self._cpu_gpu_move_volume

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                self._compute_list.append(stateful_tensor)

    @functools.lru_cache(maxsize=None)
    def _get_layout_info(self, compute_idx: int, warmup: bool):
        move_to_cuda_tensor_list = []
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue

            if tensor.device.type == 'cuda':
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            elif tensor.device.type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
            else:
                raise RuntimeError
        return move_to_cuda_tensor_list, hold_cuda_tensor_list
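
A CPU-only smoke test of the manager (not part of the diff; it assumes the sibling modules from this commit are importable). With no tensor in the COMPUTE state, `adjust_layout()` moves nothing:

import torch
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory

policy = TensorPlacementPolicyFactory.create('cpu')()
mgr = StatefulTensorMgr(policy)
mgr.register_stateful_tensor_list([StatefulTensor(torch.zeros(256))])    # HOLD, on CPU

mgr.start_iter()
mgr.adjust_layout()               # cuda_demand is 0, so no CPU<->GPU traffic occurs
print(mgr.cpu_gpu_move_volume)    # 0
mgr.finish_iter()                 # ends warmup and resets the per-iteration counters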
colossalai/gemini/tensor_placement_policy.py (new file, mode 100644)

from abc import ABC, abstractmethod
from time import time
from typing import List, Optional, Type

import torch

from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.memory_tracer import MemStatsCollector
import functools


class TensorPlacementPolicy(ABC):

    def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        self.device: Optional[torch.device] = device
        self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector

    @abstractmethod
    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
        raise NotImplementedError


class CPUTensorPlacementPolicy(TensorPlacementPolicy):

    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
        volume = 0
        for t in hold_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, self.device)
            volume += t.payload.numel() * t.payload.element_size()
        return volume, 0


class CUDATensorPlacementPolicy(TensorPlacementPolicy):

    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
        super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
        return 0, 0


class AutoTensorPlacementPolicy(TensorPlacementPolicy):

    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        super().__init__(None, mem_stats_collector=mem_stats_collector)
        # model data will use 1 - self._warmup_non_model_data_ratio CUDA memory in the warmup phase
        # TODO(ver217): make these args configurable
        self._warmup_non_model_data_ratio: float = 0.8
        self._steady_cuda_cap_ratio: float = 0.9

    def evict_tensors(self,
                      hold_cuda_tensor_list: List[StatefulTensor],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: List[StatefulTensor] = [],
                      compute_idx: int = 0,
                      **kwargs) -> int:
        """
        Evict tensors from the CUDA device.

        Args:
            hold_cuda_tensor_list (List[StatefulTensor]): the list of tensors in a HOLD-like state
            cuda_demand (int, optional): the volume of data needed on the cuda device. Defaults to 0.
            warmup (bool, optional): a flag indicating whether we are in the warmup phase. Defaults to True.
            compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
            compute_idx (int, optional): the index of the current compute step in ``compute_list``. Defaults to 0.

        Raises:
            RuntimeError: if not enough CUDA memory can be freed.

        Returns:
            Tuple[int, float]: the volume of memory that is evicted and the time spent evicting.
        """
        start = time()
        cuda_capacity = colo_device_memory_capacity(get_current_device())
        used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
            cuda_capacity *= self._steady_cuda_cap_ratio
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
        freed_cuda_model_data = 0
        end = time()
        if avail_cuda_model_data < cuda_demand:
            # Move cuda_demand - avail_cuda_model_data volume of tensors
            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_tensor_list = hold_cuda_tensor_list
            if not warmup:
                to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
                                                                   tuple(compute_list))
                # print(self._sort_hold_cuda_tensors.cache_info())
            end = time()
            for t in to_free_tensor_list:
                if freed_cuda_model_data >= to_free_cuda_model_data:
                    break
                freed_cuda_model_data += t.payload_size
                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
            if freed_cuda_model_data < to_free_cuda_model_data:
                raise RuntimeError(
                    f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
        return freed_cuda_model_data, end - start

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list:
        next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors}
        for i in range(len(compute_list) - 1, compute_idx, -1):
            if compute_list[i] in next_compute_idx:
                next_compute_idx[compute_list[i]] = i
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]


class TensorPlacementPolicyFactory:

    @staticmethod
    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
        if policy_name == 'cpu':
            return CPUTensorPlacementPolicy
        elif policy_name == 'cuda':
            return CUDATensorPlacementPolicy
        elif policy_name == 'auto':
            return AutoTensorPlacementPolicy
        else:
            raise TypeError(f"Unknown tensor placement policy {policy_name}")
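
Factory resolution is by name only; instantiation is left to the caller. A small sketch (not part of the diff):

policy_cls = TensorPlacementPolicyFactory.create('cpu')
policy = policy_cls()                 # the CPU policy needs no MemStatsCollector
print(policy.device)                  # cpu
# 'auto' returns AutoTensorPlacementPolicy, which reads its eviction headroom from a
# MemStatsCollector after the warmup iteration; unknown names raise TypeError.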
colossalai/gemini/tensor_utils.py (new file, mode 100644)

import torch
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple


def is_storage_empty(tensor: torch.Tensor) -> bool:
    return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
    if not is_storage_empty(tensor):
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    if is_storage_empty(tensor):
        tensor.storage().resize_(tensor.numel())


def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
    if isinstance(tensor, StatefulTensor):
        t = tensor.payload
    elif isinstance(tensor, torch.Tensor):
        t = tensor
    else:
        return 0, 0

    cuda_use, cpu_use = 0, 0

    mem_use = t.storage().size() * t.element_size()
    if t.device.type == 'cuda':
        cuda_use += mem_use
    elif t.device.type == 'cpu':
        cpu_use += mem_use

    return cuda_use, cpu_use


def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor],
                                tgt_t: Union[StatefulTensor, torch.Tensor]) -> None:
    """
    A colossal API for model data tensor moves.
    The source and target tensors may reside on either CPU or GPU.

    NOTE() The source tensor payload will be removed after this function.

    The function will record the communication volume between CPU and GPU.

    Args:
        src_t (Union[StatefulTensor, torch.Tensor]): source tensor
        tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
    """
    if isinstance(src_t, StatefulTensor):
        src_t_payload = src_t.payload
    else:
        src_t_payload = src_t.data
    src_dev = src_t_payload.device
    if isinstance(tgt_t, StatefulTensor):
        tgt_t_payload = tgt_t.payload
    else:
        tgt_t_payload = tgt_t.data

    tgt_t_payload.copy_(src_t_payload)

    # remove the payload of src_t
    if isinstance(src_t, StatefulTensor):
        src_t.set_null()
    else:
        src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)


def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor],
                                       target_device: Union[torch.device, int]) -> None:
    """
    Move a tensor to the target_device.

    Args:
        t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
        target_device: a target device; if the type is int, it is the index of the cuda card.
    """
    if not isinstance(target_device, torch.device):
        target_device = torch.device(f'cuda:{target_device}')

    if isinstance(t, torch.Tensor):
        t.data = t.data.to(target_device)
    elif isinstance(t, StatefulTensor):
        t.move_to(target_device)
    else:
        raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')


def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
    """colo_model_data_move_to_cpu

    Move a model data tensor from gpu to cpu.

    Args:
        t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
    """
    # TODO() optimize the tensor moving with non-blocking
    if isinstance(t, torch.Tensor):
        t.data = t.data.cpu()
    elif isinstance(t, StatefulTensor):
        t.move_to(torch.device('cpu'))
    else:
        raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')


def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
    """
    Clone a model data tensor.

    Args:
        t (Union[StatefulTensor, torch.Tensor]): a model data tensor
        target_device (torch.device): the target device

    Returns:
        torch.Tensor: a cloned torch tensor
    """
    # TODO() rename this function
    colo_model_data_tensor_move_inline(t, target_device)
    t_payload = t.payload if isinstance(t, StatefulTensor) else t
    return t_payload
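
The move helpers operate in place and release the source payload; a CPU-only sketch (not part of the diff):

import torch

src = torch.arange(4, dtype=torch.float32)
tgt = torch.zeros(4)
colo_model_data_tensor_move(src, tgt)    # copy the payload, then empty the source
print(tgt)                               # tensor([0., 1., 2., 3.])
print(src.numel())                       # 0

buf = torch.ones(8)
free_storage(buf)                 # storage shrinks to 0 bytes; the tensor object survives
print(is_storage_empty(buf))      # True
alloc_storage(buf)                # storage is resized back to numel() elements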
colossalai/global_variables.py (new file, mode 100644)

from typing import Optional


class TensorParallelEnv(object):
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = object.__new__(cls, *args, **kwargs)
        return cls._instance

    def __init__(self, *args, **kwargs):
        self.load(*args, **kwargs)

    def load(self,
             mode: Optional[str] = None,
             vocab_parallel: bool = False,
             parallel_input_1d: bool = False,
             summa_dim: int = None,
             tesseract_dim: int = None,
             tesseract_dep: int = None,
             depth_3d: int = None,
             input_group_3d=None,
             weight_group_3d=None,
             output_group_3d=None,
             input_x_weight_group_3d=None,
             output_x_weight_group_3d=None):
        self.mode = mode
        self.vocab_parallel = vocab_parallel
        self.parallel_input_1d = parallel_input_1d
        self.summa_dim = summa_dim
        self.tesseract_dim = tesseract_dim
        self.tesseract_dep = tesseract_dep
        self.depth_3d = depth_3d
        self.input_group_3d = input_group_3d
        self.weight_group_3d = weight_group_3d
        self.output_group_3d = output_group_3d
        self.input_x_weight_group_3d = input_x_weight_group_3d
        self.output_x_weight_group_3d = output_x_weight_group_3d

    def save(self):
        return dict(mode=self.mode,
                    vocab_parallel=self.vocab_parallel,
                    parallel_input_1d=self.parallel_input_1d,
                    summa_dim=self.summa_dim,
                    tesseract_dim=self.tesseract_dim,
                    tesseract_dep=self.tesseract_dep,
                    depth_3d=self.depth_3d,
                    input_group_3d=self.input_group_3d,
                    weight_group_3d=self.weight_group_3d,
                    output_group_3d=self.output_group_3d,
                    input_x_weight_group_3d=self.input_x_weight_group_3d,
                    output_x_weight_group_3d=self.output_x_weight_group_3d)


tensor_parallel_env = TensorParallelEnv()
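
`TensorParallelEnv` is a process-wide singleton: `__new__` always hands back the same object and each construction re-runs `load()`. A quick sketch (not part of the diff):

env_a = TensorParallelEnv(mode='1d', parallel_input_1d=True)
env_b = TensorParallelEnv()                       # same object; fields reset to defaults
print(env_a is env_b is tensor_parallel_env)      # True
tensor_parallel_env.load(mode='2d', summa_dim=2)
print(env_a.save()['mode'], env_a.save()['summa_dim'])    # 2d 2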
colossalai/initialize.py (new file, mode 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import argparse
import os
import pprint
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.logging import get_dist_logger
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.engine import Engine
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
from colossalai.utils.moe import sync_moe_model_param
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.engine.gradient_accumulation import accumulate_gradient
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


def get_default_parser():
    """Reads the user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

    Returns:
        Namespace: Returns the parser with the default arguments; the user may add customized arguments into this parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host', type=str, help='the master address for distributed training')
    parser.add_argument('--port', type=int, help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--rank', type=int, help='rank for the default process group')
    parser.add_argument('--local_rank', type=int, help='local rank on the node')
    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
    return parser


def launch(config: Union[str, Path, Config, Dict],
           rank: int,
           world_size: int,
           host: str,
           port: int,
           backend: str = 'nccl',
           local_rank: int = None,
           seed: int = 1024,
           verbose: bool = True):
    """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
    arguments is not given. Then it initializes and sets the distributed environment by calling global_context's functions.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        rank (int): Rank for the default process group
        world_size (int): World size of the default process group
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        local_rank (int, optional):
            Rank for the process on the node and is used to set the default CUDA device,
            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        Exception: Raise exception when config type is wrong
    """
    gpc.verbose = verbose

    # set config
    assert isinstance(config, (Config, str, Path, dict)), \
        f'expected argument config to be Config, str or Path, but got {type(config)}'
    if not isinstance(config, Config) and isinstance(config, dict):
        config = Config(config)
    if isinstance(config, (str, Path)):
        config = Config.from_file(config)
    gpc.load_config(config)

    # init default process group
    gpc.init_global_dist(rank, world_size, backend, host, port)

    # init process groups for different parallel modes from config
    gpc.init_parallel_groups()

    # set cuda device
    if torch.cuda.is_available():
        # if local rank is not given, calculate automatically
        gpc.set_device(local_rank)

    # set the number of processes running on the same node
    gpc.detect_num_processes_on_current_node()

    gpc.set_seed(seed)

    if verbose:
        logger = get_dist_logger()
        logger.info(
            f'Distributed environment is initialized, '
            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
            f'tensor parallel size: {gpc.tensor_parallel_size}',
            ranks=[0])


def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for the SLURM launcher, reading rank and world size from the environment variables
    set by SLURM.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NPROCS'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
        )

    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    """A wrapper for colossalai.launch for the OpenMPI launcher, reading rank and world size from the environment variables
    set by OpenMPI.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for torchrun or torch.distributed.launch, reading rank and world size
    from the environment variables set by PyTorch.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        host = os.environ['MASTER_ADDR']
        port = int(os.environ['MASTER_PORT'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def initialize(model: nn.Module,
               optimizer: Optimizer,
               criterion: Optional[_Loss] = None,
               train_dataloader: Optional[Iterable] = None,
               test_dataloader: Optional[Iterable] = None,
               lr_scheduler: Optional[_LRScheduler] = None,
               ophooks: Optional[List[BaseOpHook]] = None,
               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
            Your optimizer instance.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
        verbose (bool, optional): Whether to print logs.

    Returns:
        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``engine`` could not be None.
    """
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n",
            ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', False)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # zero
    use_zero = hasattr(gpc.config, 'zero')
    if use_zero:
        zero_cfg = gpc.config.get('zero', None)
        if zero_cfg is not None:
            cfg_ = zero_cfg.copy()
        else:
            cfg_ = {}
        optimizer_config = zero_cfg.get('optimizer_config', None)
        model_config = zero_cfg.get('model_config', None)
        model, optimizer = convert_to_zero_v2(model,
                                              optimizer,
                                              model_config=model_config,
                                              optimizer_config=optimizer_config)

        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
    else:
        if isinstance(model, nn.Module):
            # first sync model across dp ranks
            model.to(get_current_device())
        elif isinstance(model, Callable):
            model = model().to(get_current_device())
            # optimizer may be an optimizer_cls
            logger.warning("Initializing a non-ZeRO model with an optimizer class")

        if isinstance(optimizer, Callable):
            optimizer = optimizer(model.parameters())

    if not use_zero:
        if is_using_sequence():
            sync_model_param(model, ParallelMode.SEQUENCE_DP)
        elif MOE_CONTEXT.is_initialized:
            sync_moe_model_param(model)
        elif is_using_ddp():
            sync_model_param(model, ParallelMode.DATA)
    else:
        logger.warning(
            "The parameters of models are not automatically synchronized.\n"
            "Please make sure that all parameters are the same in the data parallel group.",
            ranks=[0])

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)

    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)

    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        if is_using_pp():
            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only supports NaiveAMP currently'
        if amp_mode == AMP_TYPE.NAIVE:
            cfg_['clip_grad_norm'] = clip_grad_norm
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    # get torch ddp config
    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if gradient handler is not specified in the configuration file,
        # check in the following order
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_sequence():
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                            ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.DATA),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, "
                    "DataParallelGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        # add pipeline parallel gradient handler, if pipeline shared module is detected
        for param in model.parameters():
            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
                if gradient_handler_cfg is None:
                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
                else:
                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
                if verbose:
                    logger.info(
                        "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
                        "added even though not specified in the configuration",
                        ranks=[0])
                break
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
            )

    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
    # to avoid duplicated buffer synchronization
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False

    # initialize schedule for engine
    if is_using_pp():
        tensor_shape = get_tensor_shape()
        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
            scatter_gather = True
        else:
            scatter_gather = False
        if use_interleaved:
            if isinstance(model, nn.Sequential):
                model = nn.ModuleList([model])
            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                   gpc.config.model.num_chunks,
                                                   tensor_shape=tensor_shape,
                                                   scatter_gather_tensors=scatter_gather)
        else:
            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                        tensor_shape=tensor_shape,
                                        scatter_gather_tensors=scatter_gather)
    else:
        schedule = NonPipelineSchedule()

    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is ColossalaiOptimizer
    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
        optimizer = ColossalaiOptimizer(optim=optimizer)

    # gradient accumulation
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
            model=model,
            optimizer=optimizer,
            dataloader=train_dataloader,
            accumulate_size=grad_accum_size,
            gradient_handlers=gradient_handlers,
            lr_scheduler=lr_scheduler)

    engine = Engine(model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    gradient_handlers=gradient_handlers,
                    clip_grad_norm=clip_grad_norm,
                    ophook_list=ophooks,
                    schedule=schedule)

    return engine, train_dataloader, test_dataloader, lr_scheduler
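
The typical entry point built on these functions looks roughly as follows (a hedged sketch, not part of the diff; it assumes the package exports `get_default_parser`, `launch_from_torch`, and `initialize` as in upstream ColossalAI, and a user-supplied `config.py`):

# train.py -- run with: torchrun --nproc_per_node=1 train.py --config config.py
import colossalai
import torch

args = colossalai.get_default_parser().parse_args()
colossalai.launch_from_torch(config=args.config)   # RANK/WORLD_SIZE/... come from torchrun

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model, optimizer=optimizer, criterion=criterion)
engine.train()    # the returned Engine drives forward/backward/step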
colossalai/kernel/__init__.py (new file, mode 100644)

from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention

__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
colossalai/kernel/cuda_native/__init__.py (new file, mode 100644)

from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax
colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp (new file, mode 100644)

// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int mode,
                            const int bias_correction,
                            const float weight_decay, const float div_scale);

void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int bias_correction,
                            const float weight_decay, const int grad_averaging,
                            const int mode, at::Tensor global_grad_norm,
                            const float max_grad_norm,
                            at::optional<bool> use_nvlamb_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and apply update for LAMB optimizer");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
}
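
From Python, the bindings above follow apex's multi-tensor applier convention: a chunk size, an overflow flag, lists of tensor lists, then the scalar arguments. A hedged sketch (not part of the diff) assuming the extension builds under the name `colossal_C` and a CUDA device is present:

import torch
import colossal_C    # TORCH_EXTENSION_NAME as built from this file (assumed name)

noop_flag = torch.zeros(1, dtype=torch.int, device='cuda')    # set non-zero on overflow
srcs = [torch.randn(10, device='cuda') for _ in range(3)]
dsts = [torch.empty_like(t) for t in srcs]

# writes 0.5 * src into dst for every (src, dst) pair, chunk by chunk
colossal_C.multi_tensor_scale(2048 * 32, noop_flag, [srcs, dsts], 0.5)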
colossalai/kernel/cuda_native/csrc/compat.h (new file, mode 100644)

// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
(no newline at end of file)
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
0 → 100644
View file @
08f2920e
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <math.h>
#include <omp.h>
#include <string.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
// C++ interface
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            float loss_scale) {
  size_t rounded_size = 0;

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);
  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);
  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      AVX_Data grad_4;
      if (grad_half_precision) {
        grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
      } else {
        grad_4.data = SIMD_LOAD(grads + i);
      }
      if (loss_scale > 0) {
        AVX_Data loss_scale_vec;
        loss_scale_vec.data = SIMD_SET(loss_scale);
        grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
      }
      AVX_Data momentum_4;
      momentum_4.data = SIMD_LOAD(_exp_avg + i);
      AVX_Data variance_4;
      variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
      AVX_Data param_4;
      if (param_half_precision) {
        param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
      } else {
        param_4.data = SIMD_LOAD(_params + i);
      }
      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
      }
      momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
      momentum_4.data =
          SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
      variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
      grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
      variance_4.data =
          SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
      grad_4.data = SIMD_SQRT(variance_4.data);
      grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
      grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
      if (_weight_decay > 0 && _adamw_mode) {
        param_4.data =
            SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
      }
      param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
      if (param_half_precision) {
        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
      } else {
        SIMD_STORE(_params + i, param_4.data);
      }
      SIMD_STORE(_exp_avg + i, momentum_4.data);
      SIMD_STORE(_exp_avg_sq + i, variance_4.data);
    }
  }
#endif
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;
#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param =
            param_half_precision ? (float)params_cast_h[k] : _params[k];
        float momentum = _exp_avg[k];
        float variance = _exp_avg_sq[k];
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;
        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;
        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;
        if (param_half_precision)
          params_cast_h[k] = (__half)param;
        else
          _params[k] = param;
        _exp_avg[k] = momentum;
        _exp_avg_sq[k] = variance;
      }
    }
  }
}
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            float loss_scale) {
  size_t rounded_size = 0;

  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);
  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      AVX_Data grad_4[4];
      AVX_Data momentum_4[4];
      AVX_Data variance_4[4];
      AVX_Data param_4[4];
#pragma unroll 4
      for (int j = 0; j < 4; j++) {
        if (grad_half_precision) {
          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
        } else {
          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
        }
        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
        if (param_half_precision) {
          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
        } else {
          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
        }
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data =
            SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        if (param_half_precision) {
          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
                          param_4[j].data);
        } else {
          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
        }
        SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
        SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, loss_scale);
}
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            float loss_scale) {
  size_t rounded_size = 0;

  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);
  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      AVX_Data grad_4[8];
      AVX_Data momentum_4[8];
      AVX_Data variance_4[8];
      AVX_Data param_4[8];
#pragma unroll 8
      for (int j = 0; j < 8; j++) {
        if (grad_half_precision) {
          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
        } else {
          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
        }
        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
        if (param_half_precision) {
          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
        } else {
          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
        }
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data =
            SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        if (param_half_precision) {
          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
                          param_4[j].data);
        } else {
          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
        }
        SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
        SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, loss_scale);
}
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
                          float epsilon, float weight_decay,
                          bool bias_correction, torch::Tensor &params,
                          torch::Tensor &grads, torch::Tensor &exp_avg,
                          torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  float *params_ptr = (float *)params_c.data_ptr();
  float *grads_ptr = (float *)grads_c.data_ptr();
  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();

  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
               params_c.numel(), (params.options().dtype() == at::kHalf),
               (grads.options().dtype() == at::kHalf), loss_scale);
}

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &Adam_Optimizer::step);
}
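The step() wrapper cascades Step_8 → Step_4 → Step_1, so the widest SIMD path handles the bulk of the parameter buffer and the scalar tail handles the remainder. A hypothetical C++ usage sketch (not part of this commit); the tensor sizes and hyperparameters are illustrative, and loss_scale < 0 disables gradient unscaling, matching the `loss_scale > 0` checks in the kernels above:

// Hypothetical sketch, not part of this commit.
#include <torch/torch.h>

void cpu_adam_example() {
  Adam_Optimizer opt(/*alpha=*/1e-3, /*betta1=*/0.9, /*betta2=*/0.999,
                     /*eps=*/1e-8, /*weight_decay=*/0.0, /*adamw_mode=*/true);
  auto p = torch::zeros({1024});  // CPU tensors: params, grads, Adam moments
  auto g = torch::randn({1024});
  auto m = torch::zeros({1024});
  auto v = torch::zeros({1024});
  opt.step(/*step=*/1, /*lr=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f,
           /*epsilon=*/1e-8f, /*weight_decay=*/0.0f, /*bias_correction=*/true,
           p, g, m, v, /*loss_scale=*/-1.0f);
}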
colossalai/kernel/cuda_native/csrc/cpu_adam.h
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <hip/hip_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_store_ps( \
x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_store_ps( \
x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#endif
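// Added note (commentary, not in the original header): ROUND_DOWN relies on
// `step` being a power of two; it clears the low bits of `size`, e.g.
//   ROUND_DOWN(1000, 16) == 992   // 1000 & ~15
// which is why the rounded_size computed in Step_1/Step_4/Step_8 is always a
// multiple of SIMD_WIDTH, SIMD_WIDTH * 4, or SIMD_WIDTH * 8 respectively.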
union AVX_Data {
#if defined(__AVX512__)
  __m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
  __m256 data;
#endif
  // float data_f[16];
};
#endif
#define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1);
class Adam_Optimizer {
 public:
  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                 float eps = 1e-8, float weight_decay = 0,
                 bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~Adam_Optimizer() {}

  STEP(1)
  STEP(4)
  STEP(8)

  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }

  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
  }

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);

 private:
  float _alpha;
  float _betta1;
  float _betta2;
  float _eps;
  float _weight_decay;
  float _betta1_t;
  float _betta2_t;
  size_t _step;
  float _bias_correction1;
  float _bias_correction2;
  bool _adamw_mode;
};
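For reference, the update the Step_* kernels implement once bias correction is on (update_state sets _bias_correction1 $= 1-\beta_1^t$ and _bias_correction2 $= 1/\sqrt{1-\beta_2^t}$):

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,$$
$$\theta_t = \theta_{t-1} - \frac{\alpha}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t}\,/\sqrt{1-\beta_2^t} + \epsilon}.$$

In adamw_mode the weight decay is decoupled, $\theta \leftarrow \theta\,(1 - \alpha\lambda)$ before the update; otherwise it is folded into the gradient as $g \leftarrow g + \lambda\theta$.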
colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
#include "block_reduce.h"
#include "cuda_util.h"
#include "kernels.h"
#include "ls_cub.cuh"
ls::cub::CachingDeviceAllocator g_allocator(true);
template <typename T>
__global__ void ls_cross_entropy_fw_kernel(
    const T *__restrict__ inputs, const int *__restrict__ targets,
    float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
    const int padding_idx, const float epsilon, const int vocab_size) {
  /* step1: compute each thread's max_logit and sum_exp_logit, store in
   * max_input, sum_exp_logit */
  const int block_start = blockIdx.x * vocab_size;
  const int left_idx = block_start + threadIdx.x;
  const int right_idx = (blockIdx.x + 1) * vocab_size;
  float max_input[1] = {REDUCE_FLOAT_INF_NEG};
  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp

  int target_tid = targets[blockIdx.x];
  if (target_tid == padding_idx) {
    if (threadIdx.x == 0) {
      nll_loss_outputs[blockIdx.x] = 0.f;
      outputs[blockIdx.x] = 0.f;
    }
    return;
  }

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
  }
  blockReduce<ReduceType::kMax, 1>(max_input);
  __shared__ float s_max_input;
  if (threadIdx.x == 0) {
    s_max_input = max_input[0];
  }
  __syncthreads();

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    float logit = static_cast<float>(inputs[i]) - s_max_input;
    sum_logits[0] += logit;
    sum_logits[1] += expf(logit);
  }
  blockReduce<ReduceType::kSum, 2>(sum_logits);
  __shared__ float s_sum_logit;
  __shared__ float s_sum_exp;
  if (threadIdx.x == 0) {
    s_sum_logit = sum_logits[0];
    s_sum_exp = sum_logits[1];
  }
  __syncthreads();

  float eps_i = epsilon / (vocab_size - 1);
  if (threadIdx.x == 0) {
    // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
    float nll_loss = logf(s_sum_exp) -
                     static_cast<float>(inputs[block_start + target_tid]) +
                     s_max_input;
    nll_loss_outputs[blockIdx.x] = nll_loss;
    float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
    outputs[blockIdx.x] =
        (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
  }
}
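// Added commentary (not in the original source): per token, the forward
// kernel computes the label-smoothed loss
//   L = (1 - eps - eps_i) * nll(target) + eps_i * sum_j nll(j),
//   eps_i = eps / (vocab_size - 1),
// where nll(j) = log(sum_k exp(x_k - x_max)) - (x_j - x_max). Summing over j,
//   sum_j nll(j) = vocab_size * log(sum_k exp(x_k - x_max))
//                  - sum_j (x_j - x_max),
// which is exactly the s_sum_exp / s_sum_logit combination stored above, so
// the smoothed term costs no extra pass over the vocabulary.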
template <typename T>
__global__ void ls_cross_entropy_bw_kernel(
    const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
    const int *__restrict__ targets, T *__restrict__ grad_inputs,
    const int padding_idx, const float epsilon, const int vocab_size) {
  /* step1: compute each thread's max_logit and sum_exp_logit, store in
   * max_input, sum_exp_logit */
  const int block_start = blockIdx.x * vocab_size;
  const int left_idx = block_start + threadIdx.x;
  const int right_idx = (blockIdx.x + 1) * vocab_size;
  float max_input[1] = {REDUCE_FLOAT_INF_NEG};
  float sum_logits[1] = {0.f};
  const float grad_out = static_cast<float>(grad_outputs[0]);

  int target_tid = targets[blockIdx.x];
  if (target_tid == padding_idx) {
    for (int i = left_idx; i < right_idx; i += blockDim.x) {
      grad_inputs[i] = 0.f;
    }
    return;
  }

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
  }
  blockReduce<ReduceType::kMax, 1>(max_input);
  __shared__ float s_max_input;
  if (threadIdx.x == 0) {
    s_max_input = max_input[0];
  }
  __syncthreads();

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    float logit = static_cast<float>(inputs[i]) - s_max_input;
    sum_logits[0] += expf(logit);
  }
  blockReduce<ReduceType::kSum, 1>(sum_logits);
  __shared__ float s_sum_exp;
  if (threadIdx.x == 0) {
    s_sum_exp = sum_logits[0];
  }
  __syncthreads();

  float eps_i = epsilon / (vocab_size - 1);
  float nll_weight = 1.0 - epsilon - eps_i;
  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
    float grad = 0;
    grad += (vocab_size * prob - 1) * eps_i;
    grad += prob * nll_weight;
    if ((i - block_start) == target_tid) {
      grad -= nll_weight;
    }
    grad_inputs[i] = grad_out * grad;
  }
}
template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
                             float *outputs_ptr, float *nll_loss_ptr,
                             float *loss_buffer, const int padding_idx,
                             const float epsilon, const int batch_size,
                             const int seq_len, const int vocab_size,
                             cudaStream_t stream) {
  int grid_dim = batch_size * seq_len;
  float *nll_loss_buffer = loss_buffer + grid_dim;
  ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
      inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
      epsilon, vocab_size);

  int num_items = grid_dim;
  void *d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage,
                                             temp_storage_bytes, loss_buffer,
                                             outputs_ptr, num_items, stream));
  CHECK_GPU_ERROR(
      g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage,
                                             temp_storage_bytes, loss_buffer,
                                             outputs_ptr, num_items, stream));
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(
      d_temp_storage, temp_storage_bytes, nll_loss_buffer, nll_loss_ptr,
      num_items, stream));
  CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
}

template void launch_cross_entropy_fw<float>(
    const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
    float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template void launch_cross_entropy_fw<__half>(
    const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
    float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr,
                             const T *inputs_ptr, const int *targets_ptr,
                             T *grad_inputs_ptr, const int padding_idx,
                             const float epsilon, const int batch_size,
                             const int seq_len, const int vocab_size,
                             cudaStream_t stream) {
  int grid_dim = batch_size * seq_len;
  ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
      grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
      epsilon, vocab_size);
}

template void launch_cross_entropy_bw<float>(
    const float *grad_outputs_ptr, const float *inputs_ptr,
    const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template void launch_cross_entropy_bw<__half>(
    const float *grad_outputs_ptr, const __half *inputs_ptr,
    const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);
colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
*/
#include "cublas_wrappers.h"
#ifdef COLOSSAL_HIP
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float *alpha, const float *beta, const float *A,
                   const float *B, float *C, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_ex(
      handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
      rocblas_datatype_f32_r, (transa == rocblas_operation_none) ? m : k,
      (const void *)B, rocblas_datatype_f32_r,
      (transb == rocblas_operation_none) ? k : n, (const void *)beta, C,
      rocblas_datatype_f32_r, m, C, rocblas_datatype_f32_r, m,
      rocblas_datatype_f32_r, algo, 0, 0);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
            m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}

int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float *alpha, const float *beta, const __half *A,
                   const __half *B, __half *C, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_ex(
      handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
      rocblas_datatype_f16_r, (transa == rocblas_operation_none) ? m : k,
      (const void *)B, rocblas_datatype_f16_r,
      (transb == rocblas_operation_none) ? k : n, (const void *)beta,
      (void *)C, rocblas_datatype_f16_r, m, (void *)C, rocblas_datatype_f16_r,
      m, rocblas_datatype_f32_r, algo, 0, 0);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
            m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
                                const float *alpha, const float *beta,
                                const float *A, const float *B, float *C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_strided_batched_ex(
      handle, op_A, op_B, m, n, k, alpha, A, rocblas_datatype_f32_r,
      (op_A == rocblas_operation_none) ? m : k, stride_A, B,
      rocblas_datatype_f32_r, (op_B == rocblas_operation_none) ? k : n,
      stride_B, beta, C, rocblas_datatype_f32_r, m, stride_C, C,
      rocblas_datatype_f32_r,  // output D aliases the f32 C buffer
      m, stride_C, batch, rocblas_datatype_f32_r, algo, 0, 0);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
            "error: %d)\n",
            batch, m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
                                const float *alpha, const float *beta,
                                const __half *A, const __half *B, __half *C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_strided_batched_ex(
      handle, op_A, op_B, m, n, k, alpha, A, rocblas_datatype_f16_r,
      (op_A == rocblas_operation_none) ? m : k, stride_A, B,
      rocblas_datatype_f16_r, (op_B == rocblas_operation_none) ? k : n,
      stride_B, beta, C, rocblas_datatype_f16_r, m, stride_C, C,
      rocblas_datatype_f16_r, m, stride_C, batch, rocblas_datatype_f32_r, algo,
      0, 0);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
            m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}
#else
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float *alpha, const float *beta, const float *A,
                   const float *B, float *C, cublasGemmAlgo_t algo) {
  cublasStatus_t status =
      cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
                   (const void *)A, CUDA_R_32F,
                   (transa == CUBLAS_OP_N) ? m : k, (const void *)B,
                   CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
                   (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
            m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}

int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float *alpha, const float *beta, const __half *A,
                   const __half *B, __half *C, cublasGemmAlgo_t algo) {
  cublasStatus_t status =
      cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
                   (const void *)A, CUDA_R_16F,
                   (transa == CUBLAS_OP_N) ? m : k, (const void *)B,
                   CUDA_R_16F, (transb == CUBLAS_OP_N) ? k : n,
                   (const void *)beta, (void *)C, CUDA_R_16F, m, CUDA_R_32F,
                   algo);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
            m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
                                const float *alpha, const float *beta,
                                const float *A, const float *B, float *C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo) {
  cublasStatus_t status = cublasGemmStridedBatchedEx(
      handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
      (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
      (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m,
      stride_C, batch, CUDA_R_32F, algo);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
            "error: %d)\n",
            batch, m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
                                const float *alpha, const float *beta,
                                const __half *A, const __half *B, __half *C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo) {
  cublasStatus_t status = cublasGemmStridedBatchedEx(
      handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
      (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
      (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m,
      stride_C, batch, CUDA_R_32F, algo);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr,
            "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
            m, n, k, (int)status);
    return EXIT_FAILURE;
  }
  return 0;
}
#endif
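A minimal, hypothetical caller sketch for the wrapper above (cuBLAS column-major convention; the leading dimensions are chosen inside the wrapper from the transpose flags, so the caller only supplies m, n, k):

// Hypothetical sketch, not part of this commit: C = A * B for column-major
// float matrices on the CUDA path.
void gemm_example(cublasHandle_t handle, const float *A, const float *B,
                  float *C, int m, int n, int k) {
  const float alpha = 1.0f, beta = 0.0f;
  int rc = cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                          &beta, A, B, C, CUBLAS_GEMM_DEFAULT);
  if (rc != 0) {
    // the wrapper already printed the cuBLAS status to stderr
  }
}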
colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#ifdef COLOSSAL_HIP
#include <thrust/transform_reduce.h>
#include "hip_util.h"
#else
#include "cuda_util.h"
#endif
/* GPU function guard */
std::string _cudaGetErrorString(cudaError_t error) {
  return cudaGetErrorString(error);
}

std::string _cudaGetErrorString(cublasStatus_t error) {
  switch (error) {
    case CUBLAS_STATUS_SUCCESS:
      return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
#ifndef COLOSSAL_HIP
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "CUBLAS_STATUS_LICENSE_ERROR";
#endif
  }
  return "CUBLAS_UNKNOWN";
}

template <typename T>
void check_gpu_error(T result, char const *const func, const char *const file,
                     int const line) {
  if (result) {
    throw std::runtime_error(std::string("[CUDA][ERROR] ") + file + "(" +
                             std::to_string(line) +
                             "): " + (_cudaGetErrorString(result)) + "\n");
  }
}

template void check_gpu_error<cudaError_t>(cudaError_t result,
                                           char const *const func,
                                           const char *const file,
                                           int const line);
template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
                                              char const *const func,
                                              const char *const file,
                                              int const line);

template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele) {
  std::cout << outn << ": ";
  std::vector<T> hout(num_output_ele, (T)0);
  cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
             cudaMemcpyDeviceToHost);
  for (int i = 0; i < num_output_ele; i++) {
    std::cout << hout[i] << ", ";
  }
  std::cout << std::endl;
}

template <>
void print_vec<__half>(const __half *outv, std::string outn,
                       int num_output_ele) {
  std::cout << outn << ": ";
  std::vector<__half> hout(num_output_ele, (__half)0.f);
  cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
             cudaMemcpyDeviceToHost);
  for (int i = 0; i < num_output_ele; i++) {
    std::cout << __half2float(hout[i]) << ", ";
  }
  std::cout << std::endl;
}

template void print_vec<float>(const float *outv, std::string outn,
                               int num_output_ele);
template void print_vec<int>(const int *outv, std::string outn,
                             int num_output_ele);
template void print_vec<__half>(const __half *outv, std::string outn,
                                int num_output_ele);

template <typename T>
T *cuda_malloc(size_t ele_num) {
  size_t byte_size = ele_num * sizeof(T);
  T *pdata = nullptr;
  CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
  return pdata;
}

template float *cuda_malloc<float>(size_t ele_num);
template __half *cuda_malloc<__half>(size_t ele_num);
template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);

void cuda_free(void *pdata) {
  if (pdata != nullptr) {
    cudaFree(pdata);
  }
}
template <typename T>
struct _isnan {
  __device__ bool operator()(T a) const { return isnan(a); }
};

template <>
struct _isnan<__half> {
  __device__ bool operator()(const __half a) const { return __hisnan(a); }
};

template <typename T>
struct _isinf {
  __device__ bool operator()(T a) const { return isinf(a); }
};

template <>
struct _isinf<__half> {
  __device__ bool operator()(const __half a) const { return __hisinf(a); }
};

template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
                   std::string file, int line, cudaStream_t stream) {
  // check_nan_inf = true  -> check for nan
  // check_nan_inf = false -> check for inf
  bool res = false;
  std::string msg = file + "(" + std::to_string(line) + "): ";
  if (check_nan_inf) {
    msg += "nan.";
    res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
                                   data_ptr + dsize, _isnan<T>(), false,
                                   thrust::logical_or<bool>());
  } else {
    msg += "inf.";
    res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
                                   data_ptr + dsize, _isinf<T>(), false,
                                   thrust::logical_or<bool>());
  }
  if (res) {
    throw std::runtime_error(msg);
  }
  std::cout << msg << " [check pass]." << std::endl;
}

template void check_nan_inf<float>(const float *data_ptr, int dsize,
                                   bool check_nan_inf, std::string file,
                                   int line, cudaStream_t stream);
template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
                                    bool check_nan_inf, std::string file,
                                    int line, cudaStream_t stream);
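A hypothetical usage sketch (not part of this commit) combining the helpers above; note that freshly allocated device memory is uninitialized, so in practice you would point check_nan_inf at real activations or gradients:

// Hypothetical sketch, not part of this commit.
void debug_example(cudaStream_t stream) {
  float *buf = cuda_malloc<float>(1024);  // checked device allocation
  // throws std::runtime_error if any element of buf is NaN
  check_nan_inf<float>(buf, 1024, /*check_nan_inf=*/true, __FILE__, __LINE__,
                       stream);
  print_vec<float>(buf, "buf", 8);        // dump the first 8 elements
  cuda_free(buf);
}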
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
#include <chrono>
#include <ctime>
#include "kernels.h"
#ifdef COLOSSAL_HIP
#include <hiprand/hiprand_kernel_hcc.h>
#endif
#ifndef COLOSSAL_HIP
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
#endif
curandStatePhilox4_32_10_t *curandstate;
/**
* @brief element-wise activation function on device, like Relu, Gelu
*
* @tparam enum class ActivationType, kRelu, kGelu
* @tparam input type
* @param any shape of float and __half2
 * @return same shape and type as the input
*/
template <ActivationType, typename T>
__forceinline__ __device__ T activation_kernel(T x);

template <>
__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
  float cdf =
      0.5f *
      (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

template <>
__device__ __half2
activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
  __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
  float2 tmp_pow = __half22float2(val_pow3);
  float2 tmp = __half22float2(val);

  tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f *
                                (tmp.x + 0.044715f * tmp_pow.x))));
  tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f *
                                (tmp.y + 0.044715f * tmp_pow.y))));
  return __hmul2(val, __float22half2_rn(tmp));
}

template <>
__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
  return fmaxf(x, 0);
}

template <>
__device__ __half2
activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
#ifdef COLOSSAL_HIP
  float2 tmp = __half22float2(x);
  return __floats2half2_rn(fmaxf(0.f, tmp.x), fmaxf(0.f, tmp.y));
#else
  return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
                           fmaxf(0.f, __half2float(x.y)));
#endif
}
/**
* @brief element-wise activation backward function on device
*
* @tparam enum class ActivationType
* @tparam input type
* @param any shape of float and __half2
 * @return same shape as the input
*/
template <ActivationType, typename T>
__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);

template <>
__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(
    float grad, float x) {
  const float sqrt_param = 0.79788456080286535587989211986876f;
  const float mul_param = 0.044715;

  float x2mul = x * x * mul_param;
  float tan_h = tanhf(sqrt_param * (x + x * x2mul));
  float dg1 = 0.5f * (1.0f + tan_h);
  float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
  float dg3 = dg2 * 3 * x2mul;
  return grad * (dg1 + dg2 + dg3);
}

template <>
__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
    __half grad, __half x_half) {
  float x = __half2float(x_half);
  const float sqrt_param = 0.79788456080286535587989211986876f;
  const float mul_param = 0.044715;

  float x2mul = x * x * mul_param;
  float tan_h = tanhf(sqrt_param * (x + x * x2mul));
  float dg1 = 0.5f * (1.0f + tan_h);
  float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
  float dg3 = dg2 * 3 * x2mul;
  return grad * __float2half(dg1 + dg2 + dg3);
}

template <>
__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(
    float grad, float x) {
  return x > 0.f ? grad : 0.f;
}

template <>
__device__ __half activation_bwd_kernel<ActivationType::kRelu, __half>(
    __half grad, __half x) {
  const __half half_zero = __float2half(0.f);
  return x > half_zero ? grad : half_zero;
}

template <>
__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
    __half2 grad2, __half2 x_half2) {
#ifdef COLOSSAL_HIP
  float2 tmp_x = __half22float2(x_half2);
  float2 tmp_grad2 = __half22float2(grad2);
  return __floats2half2_rn(tmp_x.x > 0.0 ? tmp_grad2.x : 0.0,
                           tmp_x.y > 0.0 ? tmp_grad2.y : 0.0);
#else
  const __half half_zero = __float2half(0.f);
  return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
                           x_half2.y > half_zero ? grad2.y : half_zero);
#endif
}
/**
* @brief init curand states in global memory
*
 * @thread grid_dim * block_dim to support any size of states
 * @param state persistent curand states
* @param seed seed to init states
* @return void
*/
__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
                                   int seed) {
  /* Each thread gets same seed, a different sequence
     number, no offset */
  int id = threadIdx.x + blockIdx.x * blockDim.x;
  curand_init(seed, id, 0, &state[id]);
}

void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
  cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
  int grid_dim = total_count >> 9;
  curand_init_kernel<<<grid_dim, 512, 0, stream>>>(
      curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
                       std::chrono::system_clock::now().time_since_epoch())
                       .count());
}
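// Added note (commentary, not in the original source): grid_dim is computed
// as total_count >> 9, i.e. floor(total_count / 512), so no threads are
// launched for a trailing partial block; callers are expected to pass a
// total_count that is a multiple of 512. A ceil-division launch that also
// covers a ragged tail would look like:
//   int grid_dim = (total_count + 511) >> 9;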
/**
* @brief element-wise dropout, store dropped position in mask, it's not
* in-place
*
* @thread
* gridDim.x = total_count / 1024
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param out any size of float and __half
* @param in same with out
* @param mask uint8 type, same size with out
* @param seed seed to curand
* @return void
*/
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  float *__restrict__ out,
                                  const float *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);

  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];

  float4 input4 = data4[i];
  float4 res4;
  res4.x = input4.x * scale * m[0];
  res4.y = input4.y * scale * m[1];
  res4.z = input4.z * scale * m[2];
  res4.w = input4.w * scale * m[3];
  out4[i] = res4;
}
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  __half *__restrict__ out,
                                  const __half *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const float scale = 1.f / (1.f - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = (uint8_t)(rand.x > ratio);
  m[5] = (uint8_t)(rand.y > ratio);
  m[6] = (uint8_t)(rand.z > ratio);
  m[7] = (uint8_t)(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = *m8;

  float4 val_float4 = vals_float4[i];
  float4 out_float4;
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
  __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
  __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
  __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
  out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
  out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
  out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
  out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
  outs_float4[i] = out_float4;
}
/**
* @brief element-wise dropout backward with dropout mask, it's
* not in-place
*
* @thread
* gridDim.x = total_count / 1024
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param in any size of float and __half
* @param mask uint8 type, same size with in
* @return void
*/
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      float *out, const float *in,
                                      const uint8_t *__restrict__ mask) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *in4 = reinterpret_cast<const float4 *>(in);
  const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);

  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  m4[0] = mask4[i];

  float4 input4 = in4[i];
  float4 res4;
  res4.x = input4.x * scale * static_cast<float>(m[0]);
  res4.y = input4.y * scale * static_cast<float>(m[1]);
  res4.z = input4.z * scale * static_cast<float>(m[2]);
  res4.w = input4.w * scale * static_cast<float>(m[3]);
  out4[i] = res4;
}
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      __half *out, const __half *in,
                                      const uint8_t *__restrict__ mask) {
  const __half scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);

  uint8_t m[8];
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  m8[0] = mask8[i];

  float4 val_float4 = vals_float4[i];
  float4 out_float4;
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  __half2 scale_mask_1 =
      __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
  __half2 scale_mask_2 =
      __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
  __half2 scale_mask_3 =
      __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
  __half2 scale_mask_4 =
      __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
  out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
  out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
  out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
  out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
  out4[i] = out_float4;
}
template <>
void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
                              int total_count, float ratio,
                              cudaStream_t stream, bool backward) {
  int grid_dim = total_count >> 12;
  if (!backward) {
    ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
        total_count, ratio, out, vals, mask,
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch())
            .count());
  } else {
    ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
        total_count, ratio, out, vals, mask);
  }
}

template <>
void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
                               int total_count, float ratio,
                               cudaStream_t stream, bool backward) {
  int grid_dim = total_count >> 13;
  if (!backward) {
    ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
        total_count, ratio, out, vals, mask,
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch())
            .count());
  } else {
    ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
        total_count, ratio, out, vals, mask);
  }
}
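// Hypothetical usage sketch (added commentary, not part of this file):
// forward dropout over a flat float buffer. d_out, d_in and d_mask are
// assumed pre-allocated device pointers with at least n elements (the mask is
// one byte per element); the float path vectorizes by 4, so n should be a
// multiple of 4.
void dropout_example(float *d_out, const float *d_in, uint8_t *d_mask, int n,
                     cudaStream_t stream) {
  launch_ls_dropout<float>(d_out, d_in, d_mask, n, /*ratio=*/0.1f, stream,
                           /*backward=*/false);
  // the backward pass reuses the recorded mask, e.g.
  // launch_ls_dropout<float>(d_gin, d_gout, d_mask, n, 0.1f, stream, true);
}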
/**
* @brief fused bias, dropout, and residual at the end of Attention and FFN,
* store dropped position in mask, it's not in-place
*
* @thread
* gridDim.x = total_count / 1024
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param out [batch_size, seq_len, hidden_size], float and __half
* @param in [batch_size, seq_len, hidden_size], float and __half
* @param mask [batch_size, seq_len, hidden_size], uint8 type
* @param bias [hidden_size], ffn bias
* @param residual [batch_size, seq_len, hidden_size], float and __half
* @param seed seed to curand
* @param hidden_size hidden size
* @return void
*/
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const float *__restrict__ residual,
    const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = static_cast<uint8_t>(rand.x > ratio);
  m[1] = static_cast<uint8_t>(rand.y > ratio);
  m[2] = static_cast<uint8_t>(rand.z > ratio);
  m[3] = static_cast<uint8_t>(rand.w > ratio);

  int bias_i = i % (hidden_size >> 2);
  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];
  const float4 input4 = data4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  const float4 res4 = residual4[i];
  float4 output4;

  output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
  output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
  output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
  output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;

  out4[i] = output4;
}
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const __half *__restrict__ residual,
    const int seed, const int hidden_size) {
  const __half scale = 1. / (1. - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = static_cast<uint8_t>(rand.x > ratio);
  m[1] = static_cast<uint8_t>(rand.y > ratio);
  m[2] = static_cast<uint8_t>(rand.z > ratio);
  m[3] = static_cast<uint8_t>(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = static_cast<uint8_t>(rand.x > ratio);
  m[5] = static_cast<uint8_t>(rand.y > ratio);
  m[6] = static_cast<uint8_t>(rand.z > ratio);
  m[7] = static_cast<uint8_t>(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = m8[0];

  int bias_i = i % (hidden_size >> 3);
  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  const float4 res4 = residual4[i];
  float4 out_float4;

  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
  const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
  __half2 scale_mask_1 =
      __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
  __half2 scale_mask_2 =
      __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
  __half2 scale_mask_3 =
      __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
  __half2 scale_mask_4 =
      __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
  out_half2[0] = __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1,
                         res_half2[0]);
  out_half2[1] = __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2,
                         res_half2[1]);
  out_half2[2] = __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3,
                         res_half2[2]);
  out_half2[3] = __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4,
                         res_half2[3]);
  outs_float4[i] = out_float4;
}
template <>
void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
                                       uint8_t *mask, const float *bias,
                                       const float *residual, int total_count,
                                       int dim, float ratio,
                                       cudaStream_t stream) {
  int grid_dim = total_count >> 12;
  ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual,
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count(),
      dim);
}

template <>
void launch_ls_dropout_res_bias<__half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    const __half *residual, int total_count, int dim, float ratio,
    cudaStream_t stream) {
  int grid_dim = total_count >> 13;
  ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual,
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count(),
      dim);
}
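
// Illustrative host-side usage of the launchers above (buffer names, shapes,
// and the 0.1f ratio are assumptions for this sketch, not part of this file):
//
//   int batch_size = 8, seq_len = 512, hidden_size = 1024;
//   int total_count = batch_size * seq_len * hidden_size;
//   __half *d_out, *d_in, *d_bias, *d_residual;  // device buffers (assumed)
//   uint8_t *d_mask;                             // one byte per element
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   // out = (in + bias) * mask / (1 - ratio) + residual
//   launch_ls_dropout_res_bias<__half>(d_out, d_in, d_mask, d_bias,
//                                      d_residual, total_count, hidden_size,
//                                      0.1f, stream);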
/**
* @brief fused bias and dropout backward at the end of Attention and FFN
*
* @thread
* gridDim.x = hidden_size / 8
* blockDim.x = 8
* blockDim.y = 1024 / 8 = 128
*
* @param row_size batch_size * seq_len
* @param ratio dropout ratio
* @param in_grad [batch_size, seq_len, hidden_size], input grad
* @param bias_grad [hidden_size], bias grad
* @param out_grad [batch_size, seq_len, hidden_size], output grad
* @param mask [batch_size, seq_len, hidden_size], dropout mask
* @param hidden_size
* @return void
*/
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, float *__restrict__ in_grad,
    float *__restrict__ bias_grad, const float *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  // every block generates 8 bias results
  __shared__ float tile[8][129];
#ifndef COLOSSAL_HIP
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  float local_sum = 0;

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  for (int r = threadIdx.y; r < row_size; r += 128) {
    float val = out_grad[idx];
    val *= scale * static_cast<float>(mask[idx]);
    local_sum += val;
    in_grad[idx] = val;
    idx += stride;
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  float sum = 0;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

#ifdef COLOSSAL_HIP
  for (int i = 1; i < 32; i <<= 1) sum += __shfl_down(sum, i);
#else
  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
#endif

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
    bias_grad[pos] = tile[0][threadIdx.x];
  }
}
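
// NOTE on the reduction above: the shared tile is declared [8][129] rather
// than [8][128]; the odd row stride presumably keeps the strided writes
// tile[threadIdx.x][threadIdx.y] from consecutive threads in distinct
// shared-memory banks, avoiding bank conflicts.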
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, __half *__restrict__ in_grad,
    __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
  __shared__ __half2 tile[8][129];
#ifndef COLOSSAL_HIP
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
  __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
  const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
  __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  __half2 local_sum = __float2half2_rn(0.f);

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  for (int r = threadIdx.y; r < row_size; r += 128) {
    __half2 val = out_grad2[idx];
    __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
    val *= scale * m2;
    local_sum += val;
    in_grad2[idx] = val;
    idx += stride;
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  __half2 sum = __float2half2_rn(0.f);
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

#ifdef COLOSSAL_HIP
  float2 sum_f2 = __half22float2(sum);
  for (int i = 1; i < WARP_SIZE; i <<= 1)
    sum_f2.x += __shfl_down(sum_f2.x, i);
  for (int i = 1; i < WARP_SIZE; i <<= 1)
    sum_f2.y += __shfl_down(sum_f2.y, i);
  sum = __float22half2_rn(sum_f2);
#else
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
#endif

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
    bias_grad2[pos] = tile[0][threadIdx.x];
  }
}
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
                                const uint8_t *mask, int row_size, int dim,
                                float ratio, cudaStream_t stream) {
  dim3 grid_dim((dim - 1) / 8 + 1);
  dim3 block_dim(8, 128);
  ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}

template <>
void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
                                const __half *out_grad, const uint8_t *mask,
                                int row_size, int dim, float ratio,
                                cudaStream_t stream) {
  dim >>= 1;
  dim3 grid_dim((dim - 1) / 8 + 1);
  dim3 block_dim(8, 128);
  ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}

template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
                                         const float *out_grad,
                                         const uint8_t *mask, int row_size,
                                         int dim, float ratio,
                                         cudaStream_t stream);
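
// Illustrative call (assumed buffers). Each block reduces 8 columns:
// bias_grad[c] accumulates the sum over rows r of
// out_grad[r][c] * mask[r][c] / (1 - ratio), while in_grad receives the
// same rescaled values element-wise.
//
//   int row_size = batch_size * seq_len;  // assumed shapes
//   launch_ls_dropout_bias_bwd(d_in_grad, d_bias_grad, d_out_grad, d_mask,
//                              row_size, hidden_size, 0.1f, stream);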
/**
 * @brief fused bias, activation, and dropout at the end of first ffn
 *
 * @thread
 * gridDim.x = total_count / 1024 (float, 4 elements per thread)
 *           = total_count / 2048 (__half, 8 elements per thread)
 * blockDim.x = 256
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size], float and __half
 * @param in [batch_size, seq_len, hidden_size], float and __half
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param seed seed to curand
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);

  int bias_i = i % (hidden_size >> 2);
  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];
  const float4 input4 = data4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);

  float4 output4;
  output4.x =
      activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
  output4.y =
      activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
  output4.z =
      activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
  output4.w =
      activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];
  out4[i] = output4;
}
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = (uint8_t)(rand.x > ratio);
  m[5] = (uint8_t)(rand.y > ratio);
  m[6] = (uint8_t)(rand.z > ratio);
  m[7] = (uint8_t)(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = *m8;

  int bias_i = i % (hidden_size >> 3);
  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  float4 out_float4;

  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
  __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
  __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
  __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
  __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
  out_half2[0] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
      scale_mask_1);
  out_half2[1] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
      scale_mask_2);
  out_half2[2] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
      scale_mask_3);
  out_half2[3] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
      scale_mask_4);
  outs_float4[i] = out_float4;
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 10;
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 11;
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 10;
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 11;
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}
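
// Illustrative call for the forward FFN epilogue (assumed buffers). These
// kernels implement inverted dropout: surviving elements are pre-scaled by
// 1 / (1 - ratio) at train time, so inference can simply skip the op
// without rescaling activations.
//
//   // out = gelu(in + bias) * mask / (1 - ratio)
//   launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
//       d_out, d_in, d_mask, d_bias, total_count, hidden_size, 0.1f, stream);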
/**
 * @brief fused bias, activation, and dropout backward
 *
 * @thread
 * gridDim.x = hidden_size / WARP_SIZE
 * blockDim.x = WARP_SIZE
 * blockDim.y = WARP_SIZE
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type, typename T>
__global__ void ls_dropout_act_bias_bwd_kernel(
    const int row_size, const float ratio, T *in_grad,
    T *__restrict__ bias_grad, const T *__restrict__ input,
    const T *__restrict__ bias, const T *out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
#ifndef COLOSSAL_HIP
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  int stride = hidden_size * WARP_SIZE;
  float local_sum = 0;

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
      float val = out_grad[idx];
      float in = input[idx];
      float b = bias[idx % hidden_size];
      val = activation_bwd_kernel<act_type, float>(
          val * scale * static_cast<float>(mask[idx]), in + b);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

#ifdef COLOSSAL_HIP
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_down(sum, i);
#else
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
#endif

  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  if (threadIdx.y == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
    bias_grad[pos] = tile[0][threadIdx.x];
  }
}
// @brief fused bias, activation, and dropout backward
// It is deprecated for precision reason. Keep it for future optimization.
//
// template <ActivationType act_type>
// __global__ void ls_dropout_act_bias_bwd_kernel(
// const int row_size, const float ratio, __half * in_grad,
// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
// __half *__restrict__ bias, const __half * out_grad, const uint8_t
// *__restrict__ mask, const int hidden_size) {
// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
// cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// int stride = hidden_size * WARP_SIZE;
// __half2 local_sum = __float2half2_rn(0.f);
// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
// if (col_idx < hidden_size) {
// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
// __half2 val = out_grad2[idx];
// __half2 in2 = input2[idx];
// __half2 b2 = bias2[idx % hidden_size ];
// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
//       val = activation_bwd_kernel<ActivationType::kRelu, __half2>(
//           val * scale * m2, in2 + b2);
// local_sum += val;
// in_grad2[idx] = val;
// idx += stride;
// }
// }
// tile[threadIdx.x][threadIdx.y] = local_sum;
// __syncthreads();
// __half2 sum = tile[threadIdx.y][threadIdx.x];
// __syncthreads();
// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
// __syncthreads();
// if (threadIdx.y == 0) {
// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// bias_grad2[pos] = tile[0][threadIdx.x];
// }
// }
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
                                    const T *bias, const T *out_grad,
                                    const uint8_t *mask, int row_size, int dim,
                                    float ratio, cudaStream_t stream) {
  dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
  dim3 block_dim(WARP_SIZE, WARP_SIZE);
  ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
}
// template <>
// void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
// __half *in_grad, __half *bias_grad,const __half *input, const __half
// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
// dim, float ratio, cudaStream_t stream) {
// dim >>= 1;
// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
// dim3 block_dim(WARP_SIZE, WARP_SIZE);
// ls_dropout_act_bias_bwd_kernel<ActivationType::kRelu>
// <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
// bias_grad,
// input, bias,out_grad, mask, dim);
// }
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
    float *in_grad, float *bias_grad, const float *input, const float *bias,
    const float *out_grad, const uint8_t *mask, int row_size, int dim,
    float ratio, cudaStream_t stream);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
    __half *in_grad, __half *bias_grad, const __half *input,
    const __half *bias, const __half *out_grad, const uint8_t *mask,
    int row_size, int dim, float ratio, cudaStream_t stream);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
    float *in_grad, float *bias_grad, const float *input, const float *bias,
    const float *out_grad, const uint8_t *mask, int row_size, int dim,
    float ratio, cudaStream_t stream);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
    __half *in_grad, __half *bias_grad, const __half *input,
    const __half *bias, const __half *out_grad, const uint8_t *mask,
    int row_size, int dim, float ratio, cudaStream_t stream);
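
// Illustrative call (assumed buffers). Note that `input` and `bias` must be
// the same tensors fed to the matching forward kernel: the activation
// gradient is recomputed from input + bias rather than stored.
//
//   launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
//       d_in_grad, d_bias_grad, d_input, d_bias, d_out_grad, d_mask,
//       batch_size * seq_len, hidden_size, 0.1f, stream);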
colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu
0 → 100644
#include "kernels.h"

#ifndef COLOSSAL_HIP
#include <cooperative_groups.h>

namespace cg = cooperative_groups;
#endif
/**
@brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix.
@thread
gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE
@param
inp: [rows, cols]
out: [cols]
rows: the number of rows in the matrix
cols: the number of cols in the matrix
*/
template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp,
                                  T *__restrict__ out, int rows, int cols) {
  __shared__ float tile[WARP_SIZE][WARP_SIZE];
#ifndef COLOSSAL_HIP
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
  int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  int y_stride = cols * WARP_SIZE;
  float localSum = 0;

  // Loop across matrix row
  // TODO: optimize to log complexity
  if (idx < cols) {
    int offset = flat_2dim(threadIdx.y, idx, cols);
    for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
      localSum += (float)inp[offset];
      offset += y_stride;
    }
  }

  // The sum of a row in tile is equal to the sum of a col in original matrix
  tile[threadIdx.x][threadIdx.y] = localSum;

  __syncthreads();

  // Sum the shared buffer.
  // The change of threadIdx.x is continuous
  float sum = tile[threadIdx.y][threadIdx.x];

  __syncthreads();

  // Calculate the sum of a row in tile
#ifdef COLOSSAL_HIP
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_down(sum, i);
#else
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
#endif

  if (threadIdx.x == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
    if (pos < cols) out[pos] = sum;
  }
}
// [r, c] -> [c]
template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
                                              int rows, int cols,
                                              cudaStream_t stream) {
  dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
  dim3 block_dim(WARP_SIZE, WARP_SIZE);

  column_sum_reduce<float>
      <<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
                                               int rows, int cols,
                                               cudaStream_t stream) {
  dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
  dim3 block_dim(WARP_SIZE, WARP_SIZE);

  column_sum_reduce<__half>
      <<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
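
// Illustrative call (assumed buffers). This is the generic bias-gradient
// reduction: summing a [rows, cols] gradient matrix over its rows yields
// the [cols] bias gradient.
//
//   launch_fuse_transpose_bias_kernel<float>(
//       d_grad2d, d_bias_grad, batch_size * seq_len, hidden_size, stream);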
/**
@brief: fused_add2
Add two matrices inp1 and inp2 into out.
@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
                                  int hidden_dim);

template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
                                         const float *inp2, int hidden_dim) {
  int row_id = blockIdx.x;
  int offset = flat_2dim(row_id, 0, hidden_dim);

  const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
  const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
  float4 *out_4 = reinterpret_cast<float4 *>(out);
  float4 vinp1;
  float4 vinp2;
  float4 val;

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinp1 = inp1_4[offset + i];
    vinp2 = inp2_4[offset + i];
    val.x = vinp1.x + vinp2.x;
    val.y = vinp1.y + vinp2.y;
    val.z = vinp1.z + vinp2.z;
    val.w = vinp1.w + vinp2.w;
    out_4[offset + i] = val;
  }
}

template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
                                          const __half *inp2,
                                          int hidden_dim) {
  int row_id = blockIdx.x;
  int offset = flat_2dim(row_id, 0, hidden_dim);

  const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
  const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
  float4 *out_4 = reinterpret_cast<float4 *>(out);

  float4 vinp1;
  float4 vinp2;
  float4 val;
  __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
  __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
  __half2 *h2_val = reinterpret_cast<__half2 *>(&val);

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinp1 = inp1_4[offset + i];
    vinp2 = inp2_4[offset + i];
    h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
    h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
    h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
    h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
    out_4[offset + i] = val;
  }
}
//[b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1,
                              const float *inp2, int batch_size, int seq_len,
                              int hidden_dim, cudaStream_t &stream) {
  hidden_dim >>= 2;

  dim3 grid_dim(batch_size * seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
                                                        hidden_dim);
}

template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
                               const __half *inp2, int batch_size,
                               int seq_len, int hidden_dim,
                               cudaStream_t &stream) {
  hidden_dim >>= 3;

  dim3 grid_dim(batch_size * seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
                                                        hidden_dim);
}
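
// Illustrative residual-add call (assumed buffers). hidden_dim must be a
// multiple of 4 (float) or 8 (__half), since the launchers shift it down to
// a count of float4 vectors per row before launching.
//
//   launch_fused_add2<float>(d_out, d_inp1, d_inp2, batch_size, seq_len,
//                            hidden_size, stream);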
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
                                    int sz0, int sz2, int sz1_1, int sz1_2) {
  int nele = sz0 * sz2 * (sz1_1 + sz1_2);
  int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
  if (idx >= nele) {
    return;
  }
  float4 *dst_ptr = (float4 *)output + idx;
  int idx2 = idx % sz2;
  idx = idx / sz2;
  int idx1 = idx % (sz1_1 + sz1_2);
  int idx0 = idx / (sz1_1 + sz1_2);
  float4 *src_ptr = nullptr;
  int sz1 = 0;
  if (idx1 < sz1_1) {  // from inp1
    sz1 = sz1_1;
    src_ptr = (float4 *)inp1;
  } else {  // from inp2
    idx1 -= sz1_1;
    sz1 = sz1_2;
    src_ptr = (float4 *)inp2;
  }
  src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
  dst_ptr[0] = src_ptr[0];
}
template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
                                float *output, int sz0, int sz2, int sz1_1,
                                int sz1_2, cudaStream_t stream) {
  sz2 >>= 2;
  int nele = sz0 * sz2 * (sz1_1 + sz1_2);
  int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}

template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
                                 __half *output, int sz0, int sz2, int sz1_1,
                                 int sz1_2, cudaStream_t stream) {
  sz2 >>= 3;
  int nele = sz0 * sz2 * (sz1_1 + sz1_2);
  int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}
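
// Illustrative call (assumed shapes). Concatenates [sz0, sz1_1, sz2] and
// [sz0, sz1_2, sz2] along dim 1 into [sz0, sz1_1 + sz1_2, sz2]; e.g.
// appending one decode step's K/V ([batch * heads, 1, head_dim]) to a cache
// ([batch * heads, seq_len, head_dim]). Mind the argument order (sz2
// precedes sz1_1/sz1_2), and sz2 must be a multiple of 4 (float) / 8
// (__half) for the vectorized float4 copies.
//
//   launch_concat3_dim1<float>(d_cache, d_new_kv, d_merged,
//                              batch_size * num_heads, /*sz2=*/head_dim,
//                              /*sz1_1=*/seq_len, /*sz1_2=*/1, stream);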