Commit 75c8a97a authored by Simon Layton

Simplify C++-side logic

Only support the 4 specific cases we care about
Remove more general set of switch statements
parent cac061a1
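
The change below replaces the nested switch statements over tensor count and scalar type with a flat if/else chain over the four supported (gradient, weight) combinations. As a minimal standalone sketch of that dispatch strategy (the `ScalarType` enum, `dispatch_sgd` function, and printed functor names are illustrative placeholders, not apex's actual API):

```cpp
#include <cstdio>
#include <stdexcept>

// Minimal standalone sketch (assumed names, not apex's API) of the dispatch
// strategy this commit adopts: enumerate only the four supported
// (grad type, weight type, list count) combinations and reject the rest.
enum class ScalarType { Half, Float };

void dispatch_sgd(ScalarType grad_type, ScalarType weight_type, int num_lists)
{
  // Case 1. fp16 grads, fp16 weights, fp16 momentum, no fp16 weight copy.
  if (grad_type == ScalarType::Half && weight_type == ScalarType::Half && num_lists == 3)
    std::puts("SGDFunctor<3, Half, Half>");
  // Case 2. fp16 grads applied to fp32 master weights.
  else if (grad_type == ScalarType::Half && weight_type == ScalarType::Float && num_lists == 3)
    std::puts("SGDFunctor<3, Half, float>");
  // Case 3. as case 2, plus a fourth list carrying fp16 copies of the weights.
  else if (grad_type == ScalarType::Half && weight_type == ScalarType::Float && num_lists == 4)
    std::puts("SGDFunctor<4, Half, float>");
  // Case 4. pure fp32 training.
  else if (grad_type == ScalarType::Float && weight_type == ScalarType::Float && num_lists == 3)
    std::puts("SGDFunctor<3, float, float>");
  else
    throw std::runtime_error("unsupported grad/weight/num_lists combination");
}

int main()
{
  dispatch_sgd(ScalarType::Half, ScalarType::Float, 4);  // prints "SGDFunctor<4, Half, float>"
  return 0;
}
```
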
@@ -151,45 +151,57 @@ void multi_tensor_sgd_cuda(
bool first_run)
{
auto num_tensors = tensor_lists.size();
-  switch (num_tensors) {
-    case 3:
-      switch (tensor_lists[0][0].type().scalarType()) {
-        case at::ScalarType::Half:
+  auto grad_type = tensor_lists[0][0].type().scalarType();
+  auto weight_type = tensor_lists[1][0].type().scalarType();
+  // We have 4 potentials to handle here, in terms of
+  // grad_type, param_type, momentum_type, requires_fp16_copy
+  // 1. fp16, fp16, fp16, No
+  // 2. fp16, fp32, fp32, No
+  // 3. fp16, fp32, fp32, Yes
+  // 4. fp32, fp32, fp32, No
+  // It's easier to hardcode these possibilities than to use
+  // switches etc. to handle the cross-product of cases where
+  // we don't want the majority of them.
+  // Case 1. fp16, fp16, fp16, No
+  if (grad_type == at::ScalarType::Half &&
+      weight_type == at::ScalarType::Half &&
+      num_tensors == 3) {
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
-            SGDFunctor<3, at::Half, float>(),
+      SGDFunctor<3, at::Half, at::Half>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run);
-          break;
-        case at::ScalarType::Float:
+  }
+  // Case 2. fp16, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 3) {
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
-            SGDFunctor<3, float, float>(),
+      SGDFunctor<3, at::Half, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run);
-          break;
-        default:
-          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
}
-      break;
-    case 4:
-      switch (tensor_lists[0][0].type().scalarType()) {
-        case at::ScalarType::Half:
+  // Case 3. fp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 4) {
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
@@ -202,26 +214,27 @@ void multi_tensor_sgd_cuda(
lr,
nesterov,
first_run);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<4>(
+  }
+  // Case 4. fp32, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Float &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 3) {
+    multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
-            SGDFunctor<4, float, float>(),
+      SGDFunctor<3, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run);
-          break;
-        default:
-          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
}
-    default:
-      AT_ERROR("multi_tensor_sgd takes either 3 or 4 sets of tensors, given ", num_tensors);
+  else {
+    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
+             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
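
A hedged caller-side sketch of how the 3-list and 4-list paths are presumably distinguished: in Case 3 the fourth tensor list would carry fp16 copies of the fp32 master weights. The declaration below is reconstructed from the parameters visible in this diff and may not match the real header; the wrapper name, chunk size, and hyperparameter values are illustrative only.

```cpp
#include <vector>
#include <ATen/ATen.h>

// Assumed declaration, reconstructed from the arguments visible in this diff;
// consult csrc/multi_tensor_sgd_kernel.cu for the authoritative signature.
void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd, float momentum, float dampening, float lr,
    bool nesterov, bool first_run);

// Hypothetical wrapper for "Case 3. fp16, fp32, fp32, Yes": grads, master
// weights, and momentum buffers plus a fourth list of fp16 weight copies.
void sgd_step_with_fp16_copy(at::Tensor noop_flag,
                             std::vector<at::Tensor> grads_fp16,
                             std::vector<at::Tensor> weights_fp32,
                             std::vector<at::Tensor> momenta_fp32,
                             std::vector<at::Tensor> weights_fp16_copy,
                             float lr, bool first_run)
{
  std::vector<std::vector<at::Tensor>> tensor_lists =
      {grads_fp16, weights_fp32, momenta_fp32, weights_fp16_copy};  // num_tensors == 4
  multi_tensor_sgd_cuda(/*chunk_size=*/2048 * 32,
                        noop_flag,
                        tensor_lists,
                        /*wd=*/0.0f,
                        /*momentum=*/0.9f,
                        /*dampening=*/0.0f,
                        lr,
                        /*nesterov=*/false,
                        first_run);
}
```
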