OpenDAS / apex / Commits / 69251362

Commit 69251362, authored May 11, 2020 by rohithkrn
parent cec08a41

enable multi tensor extension for bfloat16
Showing 12 changed files with 122 additions and 20 deletions.
apex/amp/_process_optimizer.py      +2  -2
apex/amp/scaler.py                  +1  -1
csrc/multi_tensor_adam.cu           +1  -1
csrc/multi_tensor_axpby_kernel.cu   +3  -3
csrc/multi_tensor_l2norm_kernel.cu  +3  -3
csrc/multi_tensor_lamb.cu           +2  -2
csrc/multi_tensor_lamb_stage_1.cu   +3  -3
csrc/multi_tensor_lamb_stage_2.cu   +2  -2
csrc/multi_tensor_novograd.cu       +1  -1
csrc/multi_tensor_scale_kernel.cu   +2  -2
csrc/multi_tensor_sgd_kernel.cu     +42 -0
csrc/type_shim.h                    +60 -0
apex/amp/_process_optimizer.py
@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+    if multi_tensor_applier.available:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,
@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
     # TODO: Centralize exposure and import error checking for the C backend.
-    if multi_tensor_applier.available and not properties.opt_level in {"O4", "O5"}:
+    if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
         optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
apex/amp/scaler.py
@@ -63,7 +63,7 @@ class LossScaler(object):
         self._unskipped = 0
         self._has_overflow = False
         self._overflow_buf = torch.cuda.IntTensor([0])
-        if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+        if multi_tensor_applier.available:
             import amp_C
             LossScaler.has_fused_kernel = multi_tensor_applier.available
             LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
csrc/multi_tensor_adam.cu
@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
   }

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "adam",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
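Reviewer note: the template argument to multi_tensor_apply counts the tensor lists the functor consumes, so the Adam path above walks four same-dtype lists (p, g, m1, m2, per the comment). A minimal sketch of that layout; the names and list order here are illustrative assumptions, not taken from this commit:

// Hypothetical sketch, not from this commit: the four same-dtype lists
// behind multi_tensor_apply<4> in the Adam path ("p,g,m1,m2").
#include <ATen/ATen.h>
#include <vector>

std::vector<std::vector<at::Tensor>> make_adam_lists(at::Tensor p, at::Tensor g)
{
  // Moment buffers allocated with the same dtype as p and g
  // ("single type across p,g,m1,m2"); the kernel defines the real order.
  return {{g}, {p}, {at::zeros_like(p)}, {at::zeros_like(p)}};
}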
csrc/multi_tensor_axpby_kernel.cu
@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
         multi_tensor_apply<3>(
           BLOCK_SIZE,
           chunk_size,
csrc/multi_tensor_l2norm_kernel.cu
@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     ret_per_tensor = at::empty({0}, float_options);
   }
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
       BLOCK_SIZE,
       chunk_size,
@@ -391,7 +391,7 @@ void multi_tensor_norm_out_cuda(
   output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

   if (norm_type == 0) {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
@@ -405,7 +405,7 @@ void multi_tensor_norm_out_cuda(
         max_chunks_per_tensor);)
   }
   else {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
csrc/multi_tensor_lamb.cu
@@ -363,7 +363,7 @@ void multi_tensor_lamb_cuda(
   // We now in-place modify grad to store update before compute its norm
   // Generally this is not a issue since people modify grad in step() method all the time
   // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
       multi_tensor_apply<4>(
         BLOCK_SIZE,
         chunk_size,
@@ -386,7 +386,7 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
         chunk_size,
csrc/multi_tensor_lamb_stage_1.cu
@@ -127,9 +127,9 @@ void multi_tensor_lamb_stage1_cuda(
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
         multi_tensor_apply<5>(
           BLOCK_SIZE,
           chunk_size,
csrc/multi_tensor_lamb_stage_2.cu
@@ -91,8 +91,8 @@ void multi_tensor_lamb_stage2_cuda(
 {
   using namespace at;
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
       multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
csrc/multi_tensor_novograd.cu
@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
   multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "novograd",
     multi_tensor_apply<3>(
       BLOCK_SIZE,
csrc/multi_tensor_scale_kernel.cu
@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
csrc/multi_tensor_sgd_kernel.cu
@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
 // 2. fp32, fp32, fp32, No
 // 3. fp16, fp32, fp32, Yes
 // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+// 5. bfp16, bfp16, bfp16, No
+// 6. bfp16, fp32, fp32, Yes
 // It's easier to hardcode these possibilities than to use
 // switches etc. to handle the cross-product of cases where
 // we don't want the majority of them.
@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum,
       scale);
   }
+  // Case 5. bfp16, bfp16, bfp16, No
+  if (grad_type == at::ScalarType::BFloat16 &&
+      weight_type == at::ScalarType::BFloat16 &&
+      num_tensors == 3)
+  {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::BFloat16, at::BFloat16>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
+  // Case 6. bfp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::BFloat16 &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 4)
+  {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<4, at::BFloat16, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
   else
   {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
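The two new branches follow this file's hardcoded-case pattern: the (grad_type, weight_type, num_tensors) triple selects one SGDFunctor instantiation. A hedged sketch of inputs that would reach Case 5 above; the helper name and sizes are illustrative assumptions, not taken from this commit:

// Hypothetical sketch: three all-BFloat16 lists (grads, weights, momentum
// buffers) and no fp32 master copy, which matches Case 5 above.
#include <ATen/ATen.h>
#include <vector>

std::vector<std::vector<at::Tensor>> make_bf16_sgd_lists()
{
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  std::vector<std::vector<at::Tensor>> tensor_lists = {
      {at::ones({1024}, opts)},    // gradients
      {at::ones({1024}, opts)},    // weights
      {at::zeros({1024}, opts)}};  // momentum buffers
  // grad_type == weight_type == at::ScalarType::BFloat16 and
  // tensor_lists.size() == 3, so SGDFunctor<3, at::BFloat16, at::BFloat16>
  // runs; a fourth fp32 master-weight list would route to Case 6 instead.
  return tensor_lists;
}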
csrc/type_shim.h
@@ -79,6 +79,66 @@
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+// TODO: We might have come up with an optimal set of dispatch macros by
+// changing the signature to have an integer suffix of number of types
+// to dispatch for as defined in upstream (e.g AT_DISPATCH_FLOATING_TYPES_AND2)
+// Refactor once all the extension ops are enabled.
+#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
 template<typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
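Both macros mirror the existing LEVEL-suffixed dispatchers above them: each case binds a scalar_t_<LEVEL> alias visible to the body, so one macro can be nested per tensor list (LEVEL 0, 1, 2, ...), as the axpby and scale kernels do. A minimal usage sketch under that reading; copy_kernel and copy_cuda are hypothetical and not part of this commit:

// Hypothetical usage sketch of the new macro: dispatch a trivial elementwise
// copy over float/half/bfloat16. scalar_t_0 is the alias the macro defines
// for LEVEL 0. Assumes csrc/type_shim.h is included for the macro.
#include <ATen/ATen.h>

template <typename T>
__global__ void copy_kernel(const T* src, T* dst, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    dst[i] = src[i];
}

void copy_cuda(at::Tensor src, at::Tensor dst)
{
  int n = src.numel();
  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(src.scalar_type(), 0, "copy_cuda",
    copy_kernel<scalar_t_0><<<(n + 255) / 256, 256>>>(
      src.data_ptr<scalar_t_0>(),
      dst.data_ptr<scalar_t_0>(),
      n);)
}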