Commit 75c8a97a authored by Simon Layton

Simplify C++-side logic

Only support the 4 specific cases we care about
Remove the more general set of switch statements
parent cac061a1
@@ -151,77 +151,90 @@ void multi_tensor_sgd_cuda(
   bool first_run)
 {
   auto num_tensors = tensor_lists.size();
-  switch (num_tensors) {
-    case 3:
-      switch (tensor_lists[0][0].type().scalarType()) {
-        case at::ScalarType::Half:
-          multi_tensor_apply<3>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<3, at::Half, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<3>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<3, float, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        default:
-          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
-      }
-      break;
-    case 4:
-      switch (tensor_lists[0][0].type().scalarType()) {
-        case at::ScalarType::Half:
-          multi_tensor_apply<4>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<4, at::Half, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<4>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<4, float, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run);
-          break;
-        default:
-          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
-      }
-    default:
-      AT_ERROR("multi_tensor_sgd takes either 3 or 4 sets of tensors, given ", num_tensors);
-  }
+  auto grad_type = tensor_lists[0][0].type().scalarType();
+  auto weight_type = tensor_lists[0][0].type().scalarType();
+
+  // We have 4 potentials to handle here, in terms of
+  // grad_type, param_type, momentum_type, requires_fp16_copy
+  // 1. fp16, fp16, fp16, No
+  // 2. fp16, fp32, fp32, No
+  // 3. fp16, fp32, fp32, Yes
+  // 4. fp32, fp32, fp32, No
+  // It's easier to hardcode these possibilities than to use
+  // switches etc. to handle the cross-product of cases where
+  // we don't want the majority of them.
+
+  // Case 1. fp16, fp16, fp16, No
+  if (grad_type == at::ScalarType::Half &&
+      weight_type == at::ScalarType::Half &&
+      num_tensors == 3) {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::Half, at::Half>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  // Case 2. fp16, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 3) {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::Half, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  // Case 3. fp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 4) {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<4, at::Half, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  // Case 4. fp32, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Float &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 3) {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, float, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run);
+  }
+  else {
+    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
+             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
+  }
   AT_CUDA_CHECK(cudaGetLastError());
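
For context, a minimal caller sketch of how the four supported configurations might map onto tensor_lists. This is not part of the commit: the list ordering ([grads, params, momentum buffers, optional fp16 param copy]) and the exact signature of multi_tensor_sgd_cuda beyond the parameters visible in this hunk are assumptions for illustration.

    // Hypothetical caller sketch -- not part of this commit. Assumes the list
    // ordering [grads, params, momenta, optional fp16 param copy] and the
    // parameter order visible in the hunk above (chunk_size, noop_flag,
    // tensor_lists, wd, momentum, dampening, lr, nesterov, first_run).
    #include <vector>
    #include <ATen/ATen.h>

    // Forward declaration of the function this commit modifies (signature assumed).
    void multi_tensor_sgd_cuda(int chunk_size,
                               at::Tensor noop_flag,
                               std::vector<std::vector<at::Tensor>> tensor_lists,
                               float wd, float momentum, float dampening,
                               float lr, bool nesterov, bool first_run);

    void run_sgd_step(std::vector<at::Tensor> grads,
                      std::vector<at::Tensor> params,      // fp16 (case 1) or fp32 master weights (cases 2-4)
                      std::vector<at::Tensor> momenta,
                      std::vector<at::Tensor> fp16_params) // empty except for case 3
    {
      // Overflow/no-op flag consumed by the kernel.
      auto noop_flag = at::zeros({1}, params[0].options().dtype(at::kInt));

      // Cases 1, 2 and 4: three lists (grads, params, momenta).
      std::vector<std::vector<at::Tensor>> tensor_lists = {grads, params, momenta};

      // Case 3 adds a fourth list holding fp16 copies of the fp32 params,
      // which is what selects the num_tensors == 4 branch above.
      if (!fp16_params.empty())
        tensor_lists.push_back(fp16_params);

      multi_tensor_sgd_cuda(/*chunk_size=*/2048 * 32,
                            noop_flag,
                            tensor_lists,
                            /*wd=*/0.0f, /*momentum=*/0.9f, /*dampening=*/0.0f,
                            /*lr=*/0.01f, /*nesterov=*/false, /*first_run=*/true);
    }

Which of the four hardcoded branches runs is then determined entirely by the gradient dtype, the weight dtype, and whether a fourth list was supplied.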