OpenDAS / apex · Commits · 69251362

Commit 69251362, authored May 11, 2020 by rohithkrn

enable multi tensor extension for bfloat16

Parent: cec08a41
Showing 12 changed files with 122 additions and 20 deletions (+122 -20)
apex/amp/_process_optimizer.py      +2  -2
apex/amp/scaler.py                  +1  -1
csrc/multi_tensor_adam.cu           +1  -1
csrc/multi_tensor_axpby_kernel.cu   +3  -3
csrc/multi_tensor_l2norm_kernel.cu  +3  -3
csrc/multi_tensor_lamb.cu           +2  -2
csrc/multi_tensor_lamb_stage_1.cu   +3  -3
csrc/multi_tensor_lamb_stage_2.cu   +2  -2
csrc/multi_tensor_novograd.cu       +1  -1
csrc/multi_tensor_scale_kernel.cu   +2  -2
csrc/multi_tensor_sgd_kernel.cu     +42 -0
csrc/type_shim.h                    +60 -0
apex/amp/_process_optimizer.py

@@ -13,7 +13,7 @@ class AmpOptimizerState(object):

 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+    if multi_tensor_applier.available:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,

@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))

     # TODO: Centralize exposure and import error checking for the C backend.
-    if multi_tensor_applier.available and not properties.opt_level in {"O4", "O5"}:
+    if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
         optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
apex/amp/scaler.py

@@ -63,7 +63,7 @@ class LossScaler(object):
         self._unskipped = 0
         self._has_overflow = False
         self._overflow_buf = torch.cuda.IntTensor([0])
-        if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+        if multi_tensor_applier.available:
             import amp_C
             LossScaler.has_fused_kernel = multi_tensor_applier.available
             LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
csrc/multi_tensor_adam.cu

@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
   }

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "adam",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
csrc/multi_tensor_axpby_kernel.cu

@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
         multi_tensor_apply<3>(
           BLOCK_SIZE,
           chunk_size,
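Why the nesting here matters for build times: each dispatch level multiplies the number of template instantiations, so the old float/half cross-product of 2^3 = 8 functor variants per call site grows to 3^3 = 27 once bfloat16 joins each level, which is what the pre-existing comment about build times alludes to. The nesting mechanics are easy to see in isolation; below is a minimal host-side sketch (the print_pair helper is hypothetical, and it assumes the macros added to csrc/type_shim.h further down), showing how distinct LEVEL values give the innermost body one independent type alias per tensor list:

    #include <ATen/ATen.h>
    #include <iostream>
    #include "type_shim.h"  // the header patched in this commit

    // Hypothetical demo: LEVEL 0 and LEVEL 1 bind scalar_t_0 and scalar_t_1
    // independently, the same way the axpby call site above binds three levels.
    void print_pair(at::ScalarType x, at::ScalarType y) {
      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(x, 0, "print_pair",
        DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(y, 1, "print_pair",
          std::cout << sizeof(scalar_t_0) << " + " << sizeof(scalar_t_1)
                    << " bytes\n";))
    }

    int main() {
      // A bf16 element is 2 bytes, a float is 4: prints "2 + 4 bytes".
      print_pair(at::ScalarType::BFloat16, at::ScalarType::Float);
      return 0;
    }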
csrc/multi_tensor_l2norm_kernel.cu

@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     ret_per_tensor = at::empty({0}, float_options);
   }

-  DISPATCH_FLOAT_AND_HALF(
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
       BLOCK_SIZE,
       chunk_size,

@@ -391,7 +391,7 @@ void multi_tensor_norm_out_cuda(
   output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);

   if (norm_type == 0)
   {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,

@@ -405,7 +405,7 @@ void multi_tensor_norm_out_cuda(
         max_chunks_per_tensor);)
   }
   else
   {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
csrc/multi_tensor_lamb.cu

@@ -363,7 +363,7 @@ void multi_tensor_lamb_cuda(
   // We now in-place modify grad to store update before compute its norm
   // Generally this is not a issue since people modify grad in step() method all the time
   // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
-  DISPATCH_FLOAT_AND_HALF(
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
       chunk_size,

@@ -386,7 +386,7 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

-  DISPATCH_FLOAT_AND_HALF(
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
     multi_tensor_apply<2>(
       BLOCK_SIZE,
       chunk_size,
csrc/multi_tensor_lamb_stage_1.cu

@@ -127,9 +127,9 @@ void multi_tensor_lamb_stage1_cuda(
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
csrc/multi_tensor_lamb_stage_2.cu

@@ -91,8 +91,8 @@ void multi_tensor_lamb_stage2_cuda(
 {
   using namespace at;

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
csrc/multi_tensor_novograd.cu

@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
   multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "novograd",
     multi_tensor_apply<3>(
       BLOCK_SIZE,
csrc/multi_tensor_scale_kernel.cu

@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
csrc/multi_tensor_sgd_kernel.cu

@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
 // 2. fp32, fp32, fp32, No
 // 3. fp16, fp32, fp32, Yes
 // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+// 5. bfp16, bfp16, bfp16, No
+// 6. bfp16, fp32, fp32, Yes
 // It's easier to hardcode these possibilities than to use
 // switches etc. to handle the cross-product of cases where
 // we don't want the majority of them.

@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum,
       scale);
   }
+  // Case 5. bfp16, bfp16, bfp16, No
+  else if (grad_type == at::ScalarType::BFloat16 &&
+           weight_type == at::ScalarType::BFloat16 &&
+           num_tensors == 3)
+  {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::BFloat16, at::BFloat16>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
+  // Case 6. bfp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::BFloat16 &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 4)
+  {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<4, at::BFloat16, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
   else
   {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
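For orientation, here is a minimal host-side sketch of how the new Case 5 path would be exercised. The extern declaration mirrors the argument order visible in the hunk above (chunk size, no-op flag, tensor lists, then the scalar hyperparameters); the tensor shapes, hyperparameter values, and build setup are illustrative assumptions, not part of this commit:

    #include <torch/extension.h>
    #include <vector>

    // Assumed to match the signature of multi_tensor_sgd_cuda in this file.
    void multi_tensor_sgd_cuda(
        int chunk_size, at::Tensor noop_flag,
        std::vector<std::vector<at::Tensor>> tensor_lists,
        float wd, float momentum, float dampening, float lr,
        bool nesterov, bool first_run, bool wd_after_momentum, float scale);

    void sgd_bf16_smoke() {
      auto bf16 = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto ints = at::TensorOptions().dtype(at::kInt).device(at::kCUDA);
      auto noop_flag = at::zeros({1}, ints);
      auto grads     = at::randn({1024}, bf16);
      auto weights   = at::randn({1024}, bf16);
      auto momenta   = at::zeros({1024}, bf16);
      // Three lists with both dtypes BFloat16 selects Case 5 above,
      // i.e. SGDFunctor<3, at::BFloat16, at::BFloat16>.
      multi_tensor_sgd_cuda(2048 * 32, noop_flag,
                            {{grads}, {weights}, {momenta}},
                            /*wd=*/0.0f, /*momentum=*/0.9f, /*dampening=*/0.0f,
                            /*lr=*/0.1f, /*nesterov=*/false, /*first_run=*/true,
                            /*wd_after_momentum=*/false, /*scale=*/1.0f);
    }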
csrc/type_shim.h

@@ -79,6 +79,66 @@
   AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
 }

+// TODO: We might come up with an optimal set of dispatch macros by changing
+// the signature to carry an integer suffix for the number of extra types to
+// dispatch, as upstream does (e.g. AT_DISPATCH_FLOATING_TYPES_AND2).
+// Refactor once all the extension ops are enabled.
+#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }

 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
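For comparison, the upstream idiom the TODO points at takes a lambda and exposes a single scalar_t alias, with extra types appended through the AND2 suffix. A rough single-level equivalent of DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16 using PyTorch's stock macro would look like the sketch below (the helper function is hypothetical); the apex macros keep the LEVEL suffix precisely because a lone scalar_t cannot be nested the way the axpby and lamb call sites require:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <iostream>

    // Hypothetical helper: report the element size selected by the dispatch.
    void print_elem_size(at::ScalarType t) {
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half, at::ScalarType::BFloat16,
          t, "print_elem_size", [&] {
            std::cout << sizeof(scalar_t) << " bytes\n";
          });
    }

    int main() {
      print_elem_size(at::ScalarType::Float);     // 4 bytes
      print_elem_size(at::ScalarType::BFloat16);  // 2 bytes
      return 0;
    }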