Commit 69251362 authored by rohithkrn

Enable multi-tensor extension for bfloat16

parent cec08a41
......@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
def _master_params_to_model_params(self):
stash = self._amp_stash
-if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+if multi_tensor_applier.available:
if len(stash.all_fp16_params) > 0:
multi_tensor_applier(
stash.multi_tensor_scale,
......@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
# TODO: Centralize exposure and import error checking for the C backend.
-if multi_tensor_applier.available and not properties.opt_level in {"O4", "O5"}:
+if multi_tensor_applier.available:
import amp_C
optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
......
......@@ -63,7 +63,7 @@ class LossScaler(object):
self._unskipped = 0
self._has_overflow = False
self._overflow_buf = torch.cuda.IntTensor([0])
-if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+if multi_tensor_applier.available:
import amp_C
LossScaler.has_fused_kernel = multi_tensor_applier.available
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
......
......@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
}
// Assume single type across p,g,m1,m2 now
-DISPATCH_DOUBLE_FLOAT_AND_HALF(
+DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(
BLOCK_SIZE,
......
......@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
......
......@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
ret_per_tensor = at::empty({0}, float_options);
}
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
......@@ -391,7 +391,7 @@ void multi_tensor_norm_out_cuda(
output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
if (norm_type == 0) {
-DISPATCH_FLOAT_AND_HALF(
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
......@@ -405,7 +405,7 @@ void multi_tensor_norm_out_cuda(
max_chunks_per_tensor);)
}
else {
-DISPATCH_FLOAT_AND_HALF(
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
......
......@@ -363,7 +363,7 @@ void multi_tensor_lamb_cuda(
// We now modify grad in-place to store the update before computing its norm
// Generally this is not an issue since people modify grad in the step() method all the time
// We could also grab a list of empty tensors to avoid this, but I'd like to save space/CPU code
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
......@@ -386,7 +386,7 @@ void multi_tensor_lamb_cuda(
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
......
......@@ -127,9 +127,9 @@ void multi_tensor_lamb_stage1_cuda(
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step);
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
......
......@@ -91,8 +91,8 @@ void multi_tensor_lamb_stage2_cuda(
{
using namespace at;
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
......
......@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);
// Assume single type across p,g,m1,m2 now
-DISPATCH_DOUBLE_FLOAT_AND_HALF(
+DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "novograd",
multi_tensor_apply<3>(
BLOCK_SIZE,
......
......@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
......
......@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// 5. bfp16, bfp16, bfp16, No
// 6. bfp16, fp32, fp32, Yes
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
......@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
wd_after_momentum,
scale);
}
// Case 5. bfp16, bfp16, bfp16, No
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::BFloat16 &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, at::BFloat16, at::BFloat16>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 6. bfp16, fp32, fp32, Yes
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, at::BFloat16, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
......
......@@ -79,6 +79,66 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// TODO: We could come up with a cleaner set of dispatch macros by
// changing the signature to take an integer suffix for the number of extra types
// to dispatch, as upstream does (e.g. AT_DISPATCH_FLOATING_TYPES_AND2).
// Refactor once all the extension ops are enabled.
#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
......
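For context, each nested DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16 / DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16 level token-pastes its LEVEL argument into a distinct type alias (scalar_t_0, scalar_t_1, ...), which is what lets call sites such as multi_tensor_scale_cuda and multi_tensor_axpby_cuda above instantiate a single functor over independently typed tensor lists, while adam and novograd dispatch only one level because they assume a single type across p, g, m1, m2. A minimal usage sketch follows; launch_example and dispatch_example are hypothetical names for illustration, not functions in this repo, and the header defining the macros above is assumed to be included.

#include <ATen/ATen.h>
#include <vector>

// Hypothetical kernel launcher templated on the input and output element types.
template <typename in_t, typename out_t>
void launch_example(const std::vector<std::vector<at::Tensor>>& lists);

void dispatch_example(at::ScalarType in_type,
                      at::ScalarType out_type,
                      const std::vector<std::vector<at::Tensor>>& lists)
{
  // Two nested dispatch levels: a bf16 input list paired with an fp32 output
  // list resolves to launch_example<at::BFloat16, float>; any dtype outside
  // float/half/bfloat16 falls through to the macro's AT_ERROR default case.
  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(in_type, 0, "dispatch_example",
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(out_type, 1, "dispatch_example",
      launch_example<scalar_t_0, scalar_t_1>(lists);))
}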