Commit 69251362 authored by rohithkrn

enable multi tensor extension for bfloat16

parent cec08a41
@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+    if multi_tensor_applier.available:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,
@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
     # TODO: Centralize exposure and import error checking for the C backend.
-    if multi_tensor_applier.available and not properties.opt_level in {"O4", "O5"}:
+    if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
         optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
...
@@ -63,7 +63,7 @@ class LossScaler(object):
         self._unskipped = 0
         self._has_overflow = False
         self._overflow_buf = torch.cuda.IntTensor([0])
-        if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+        if multi_tensor_applier.available:
             import amp_C
             LossScaler.has_fused_kernel = multi_tensor_applier.available
             LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
...
@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
   }
   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "adam",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
...
@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
        multi_tensor_apply<3>(
          BLOCK_SIZE,
          chunk_size,
...
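Note on the LEVEL argument, for context: each nested dispatch switches on its own tensor list's dtype and pastes LEVEL into a distinct alias (scalar_t_0, scalar_t_1, scalar_t_2), so the three lists handled by multi_tensor_axpby_cuda can now each independently be float, half, or bfloat16. Below is a hand-expanded sketch of two such levels; launch_axpby_sketch stands in for the multi_tensor_apply call, and all names here are illustrative, not part of the commit.

#include <ATen/ATen.h>

// Stand-in for the real chunked kernel launch; it only carries the two element
// types selected by the two dispatch levels.
template <typename x_t, typename y_t>
void launch_axpby_sketch(const at::Tensor& x, const at::Tensor& y) {
  (void)x;
  (void)y;
}

// Roughly what two nested DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16 levels expand
// to (the real axpby code nests three). Only two dtype cases are shown.
void axpby_dispatch_sketch(const at::Tensor& x, const at::Tensor& y) {
  switch (x.scalar_type()) {                 // level 0 -> scalar_t_0
    case at::ScalarType::BFloat16: {
      using scalar_t_0 = at::BFloat16;
      switch (y.scalar_type()) {             // level 1 -> scalar_t_1
        case at::ScalarType::Float: {
          using scalar_t_1 = float;
          launch_axpby_sketch<scalar_t_0, scalar_t_1>(x, y);  // bf16 x, fp32 y
          break;
        }
        default:
          AT_ERROR("axpby_dispatch_sketch not implemented for '", toString(y.scalar_type()), "'");
      }
      break;
    }
    default:
      AT_ERROR("axpby_dispatch_sketch not implemented for '", toString(x.scalar_type()), "'");
  }
}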
@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     ret_per_tensor = at::empty({0}, float_options);
   }
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
       BLOCK_SIZE,
       chunk_size,
@@ -391,7 +391,7 @@ void multi_tensor_norm_out_cuda(
   output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
   if (norm_type == 0) {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
@@ -405,7 +405,7 @@ void multi_tensor_norm_out_cuda(
         max_chunks_per_tensor);)
   }
   else {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
...
@@ -363,7 +363,7 @@ void multi_tensor_lamb_cuda(
   // We now in-place modify grad to store update before compute its norm
   // Generally this is not a issue since people modify grad in step() method all the time
   // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
       chunk_size,
@@ -386,7 +386,7 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
     multi_tensor_apply<2>(
       BLOCK_SIZE,
       chunk_size,
...
@@ -127,9 +127,9 @@ void multi_tensor_lamb_stage1_cuda(
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
...
@@ -91,8 +91,8 @@ void multi_tensor_lamb_stage2_cuda(
 {
   using namespace at;
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
         chunk_size,
...
@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
   multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);
   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "novograd",
     multi_tensor_apply<3>(
       BLOCK_SIZE,
...
@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
...
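multi_tensor_scale is the kernel that _master_params_to_model_params (patched above) and the loss scaler drive through multi_tensor_applier, and its two dispatch levels are what let the source and destination lists carry different dtypes, e.g. fp32 master params written out as bf16 model params. Below is a minimal CPU-side sketch of that per-element pattern, assuming (as the existing fp16 path does) that the arithmetic happens in fp32 and only the store narrows; scale_elementwise is an illustrative name, not the functor used by the kernel.

#include <ATen/ATen.h>
#include <cstdint>

// Per-element scale-and-cast with independent input and output element types.
// The CUDA functor performs the same operation chunk-by-chunk on device.
template <typename in_t, typename out_t>
void scale_elementwise(const in_t* in, out_t* out, int64_t n, float scale) {
  for (int64_t i = 0; i < n; ++i) {
    float v = static_cast<float>(in[i]) * scale;  // compute in fp32
    out[i] = static_cast<out_t>(v);               // narrow on store (fp16/bf16) or stay fp32
  }
}

// e.g. master -> model copy:  scale_elementwise<float, at::BFloat16>(master, model, n, 1.0f);
// e.g. gradient unscale:      scale_elementwise<at::BFloat16, float>(grad_bf16, grad_fp32, n, 1.0f / loss_scale);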
@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
   // 2. fp32, fp32, fp32, No
   // 3. fp16, fp32, fp32, Yes
   // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+  // 5. bfp16, bfp16, bfp16, No
+  // 6. bfp16, fp32, fp32, Yes
   // It's easier to hardcode these possibilities than to use
   // switches etc. to handle the cross-product of cases where
   // we don't want the majority of them.
@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum,
       scale);
   }
+  // Case 5. bfp16, bfp16, bfp16, No
+  else if(grad_type == at::ScalarType::BFloat16 &&
+          weight_type == at::ScalarType::BFloat16 &&
+          num_tensors == 3)
+  {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::BFloat16, at::BFloat16>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
+  // Case 6. bfp16, fp32, fp32, Yes
+  else if(grad_type == at::ScalarType::BFloat16 &&
+          weight_type == at::ScalarType::Float &&
+          num_tensors == 4)
+  {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<4, at::BFloat16, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
   else
   {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
...
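To make the two new branches concrete: num_tensors is the number of parallel tensor lists passed in, and the gradient and weight dtypes (grad_type and weight_type in the conditions) select the SGDFunctor instantiation. The sketch below builds inputs that would take the case 5 path, pure bfloat16 with no separate master copy; the helper name, the sizes, and the reading of list 2 as momentum buffers are assumptions for illustration, not taken from the commit.

#include <ATen/ATen.h>
#include <vector>

// Three parallel lists (one entry per parameter), all bfloat16 and
// num_tensors == 3, so SGDFunctor<3, at::BFloat16, at::BFloat16> is selected.
// Case 6 would instead pass bf16 grads with fp32 weights/momentum plus a
// fourth list (presumably the low-precision model copy, per the "Yes" column).
std::vector<std::vector<at::Tensor>> make_case5_lists(int num_params) {
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  std::vector<at::Tensor> grads, weights, momenta;
  for (int i = 0; i < num_params; ++i) {
    grads.push_back(at::randn({1024}, opts));    // list 0: bf16 gradients
    weights.push_back(at::randn({1024}, opts));  // list 1: bf16 parameters
    momenta.push_back(at::zeros({1024}, opts));  // list 2: bf16 momentum buffers (assumed layout)
  }
  return {grads, weights, momenta};
}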
@@ -79,6 +79,66 @@
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
+
+// TODO: We might be able to come up with a better set of dispatch macros by
+// changing the signature to take an integer suffix giving the number of types
+// to dispatch for, as defined upstream (e.g. AT_DISPATCH_FLOATING_TYPES_AND2).
+// Refactor once all the extension ops are enabled.
+#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
 template<typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
...
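Finally, a minimal usage sketch for the new macro, assuming the patched header is available as type_shim.h (its usual name in apex's csrc); print_element_size and the include path are illustrative. LEVEL is pasted into the alias name, which is why nested dispatches at levels 0, 1, and 2 can coexist without clashing, as in the axpby and scale launchers above.

#include <cstdio>
#include <ATen/ATen.h>
#include "type_shim.h"  // assumed include for the dispatch macros defined above

// Templated worker instantiated with the dtype chosen at runtime.
template <typename scalar_t>
void print_element_size(const at::Tensor& t) {
  std::printf("%s: %zu bytes per element\n", toString(t.scalar_type()), sizeof(scalar_t));
}

// Single-level dispatch: LEVEL 0 introduces the alias scalar_t_0. A float,
// half, or (now) bfloat16 tensor is accepted; anything else hits AT_ERROR.
void describe_tensor(const at::Tensor& t) {
  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(t.scalar_type(), 0, "describe_tensor",
    print_element_size<scalar_t_0>(t);)
}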