Commit 75c8a97a authored by Simon Layton

Simplify C++-side logic

Only support the 4 specific cases we care about
Remove the more general set of switch statements
parent cac061a1
@@ -151,45 +151,57 @@ void multi_tensor_sgd_cuda(
     bool first_run)
 {
   auto num_tensors = tensor_lists.size();
-  switch (num_tensors) {
-    case 3:
-      switch (tensor_lists[0][0].type().scalarType()) {
-        case at::ScalarType::Half:
-          multi_tensor_apply<3>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<3, at::Half, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<3>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<3, float, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        default:
-          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
-      }
-      break;
-    case 4:
-      switch (tensor_lists[0][0].type().scalarType()) {
-        case at::ScalarType::Half:
-          multi_tensor_apply<4>(
-            BLOCK_SIZE,
-            chunk_size,
+  auto grad_type = tensor_lists[0][0].type().scalarType();
+  auto weight_type = tensor_lists[0][0].type().scalarType();
+
+  // We have 4 potentials to handle here, in terms of
+  // grad_type, param_type, momentum_type, requires_fp16_copy
+  // 1. fp16, fp16, fp16, No
+  // 2. fp16, fp32, fp32, No
+  // 3. fp16, fp32, fp32, Yes
+  // 4. fp32, fp32, fp32, No
+  // It's easier to hardcode these possibilities than to use
+  // switches etc. to handle the cross-product of cases where
+  // we don't want the majority of them.
+
+  // Case 1. fp16, fp16, fp16, No
+  if (grad_type == at::ScalarType::Half &&
+      weight_type == at::ScalarType::Half &&
+      num_tensors == 3) {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::Half, at::Half>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  // Case 2. fp16, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 3) {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::Half, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  // Case 3. fp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 4) {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
@@ -202,26 +214,27 @@ void multi_tensor_sgd_cuda(
-            lr,
-            nesterov,
-            first_run);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<4>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<4, float, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        default:
-          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
-      }
-    default:
-      AT_ERROR("multi_tensor_sgd takes either 3 or 4 sets of tensors, given ", num_tensors);
+      lr,
+      nesterov,
+      first_run);
+  }
+  // Case 4. fp32, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Float &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 3) {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, float, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  else {
+    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
+             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
   }
   AT_CUDA_CHECK(cudaGetLastError());
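For reference, a short standalone sketch of the dispatch rule this commit hardcodes. The Scalar enum and pick_sgd_functor helper are illustrative stand-ins, not ATen or apex APIs; the returned strings name the SGDFunctor instantiation the kernel launches for each supported combination, with the case 3 instantiation inferred because that line falls in the diff's elided context. Every other combination falls through to the error path, matching the new else branch.

// Minimal sketch (not part of the commit) of the grad/weight/num_lists dispatch.
// "Scalar" stands in for at::ScalarType so the example compiles without ATen.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>

enum class Scalar { Half, Float };

// Maps (grad type, weight type, number of tensor lists) to one of the four
// supported cases and rejects everything else -- the same rule as the
// if/else chain introduced above.
std::string pick_sgd_functor(Scalar grad, Scalar weight, std::size_t num_lists) {
  // Case 1. fp16 grads, fp16 weights/momentum, no fp16 copy (3 lists)
  if (grad == Scalar::Half && weight == Scalar::Half && num_lists == 3)
    return "SGDFunctor<3, at::Half, at::Half>";
  // Case 2. fp16 grads, fp32 weights/momentum, no fp16 copy (3 lists)
  if (grad == Scalar::Half && weight == Scalar::Float && num_lists == 3)
    return "SGDFunctor<3, at::Half, float>";
  // Case 3. fp16 grads, fp32 weights/momentum, with fp16 model copy (4 lists)
  if (grad == Scalar::Half && weight == Scalar::Float && num_lists == 4)
    return "SGDFunctor<4, at::Half, float>";
  // Case 4. fp32 grads, fp32 weights/momentum, no fp16 copy (3 lists)
  if (grad == Scalar::Float && weight == Scalar::Float && num_lists == 3)
    return "SGDFunctor<3, float, float>";
  throw std::invalid_argument("unsupported gradient/weight/num_lists combination");
}

int main() {
  // fp16 gradients, fp32 master weights, plus an fp16 model copy -> 4 lists.
  std::cout << pick_sgd_functor(Scalar::Half, Scalar::Float, 4) << '\n';
  return 0;
}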