Unverified Commit 31aceeaa authored by Deyu Fu, committed by GitHub

Improvements to apex.mlp (#804)

* update fused bias relu backward kernel

* add support for not requiring first layer dgrad

* fix bug: wrong layer in requires grad

* add infrastructure for optional bias and activation; currently only supports no bias and no relu

* make bias and relu optional separately

* add sigmoid activation option
parent aad9300b
@@ -7,17 +7,19 @@ from .. import amp

 class MlpFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, *args):
-        output = mlp_cuda.forward(args)
+    def forward(ctx, bias, activation, *args):
+        output = mlp_cuda.forward(bias, activation, args)
         ctx.save_for_backward(*args)
         ctx.outputs = output
+        ctx.bias = bias
+        ctx.activation = activation
         return output[0]

     @staticmethod
     def backward(ctx, grad_o):
-        grads = mlp_cuda.backward(grad_o, ctx.outputs, ctx.saved_tensors)
+        grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)
         del ctx.outputs
-        return tuple(grads)
+        return (None, None, *grads)

 mlp_function = amp.half_function(MlpFunction.apply)
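Because forward() now takes two non-tensor leading arguments, backward() must return one value per forward argument, with None in the slots for bias and activation; that is exactly why tuple(grads) becomes (None, None, *grads). A minimal sketch of the same pattern with a toy op (the names below are illustrative only, not apex API):

import torch

class ScaledMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scale, x, w):
        # scale is a plain Python int, analogous to bias/activation above;
        # non-tensor flags ride on ctx rather than save_for_backward
        ctx.save_for_backward(x, w)
        ctx.scale = scale
        return scale * x.mm(w.t())

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # one return value per forward argument: None for the non-tensor flag,
        # real gradients for the tensor inputs
        return None, ctx.scale * grad_out.mm(w), ctx.scale * grad_out.t().mm(x)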
@@ -29,16 +31,21 @@ class MLP(torch.nn.Module):
         bias (bool): Default True:
         relu (bool): Default True
     """
-    def __init__(self, mlp_sizes, bias=True, relu=True):
-        if not (bias and relu):
-            raise TypeError("bias and relu must be both true.")
+    def __init__(self, mlp_sizes, bias=True, activation='relu'):
         super(MLP, self).__init__()
         self.num_layers = len(mlp_sizes) - 1
         self.mlp_sizes = copy(mlp_sizes)
-        self.bias = bias
-        self.relu= relu
-        # ignoring bias = False now
+        self.bias = 1 if bias else 0
+
+        if activation is 'none':
+            self.activation = 0
+        elif activation is 'relu':
+            self.activation = 1
+        elif activation is 'sigmoid':
+            self.activation = 2
+        else:
+            raise TypeError("activation must be relu or none.")
+
         self.weights = []
         self.biases = []
         for i in range(self.num_layers):
@@ -46,6 +53,7 @@ class MLP(torch.nn.Module):
             self.weights.append(w)
             name = 'weight_{}'.format(i)
             setattr(self, name, w)
+            if self.bias:
                 b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))
                 self.biases.append(b)
                 name = 'bias_{}'.format(i)
@@ -58,13 +66,14 @@ class MLP(torch.nn.Module):
             dimsum = weight.size(0) + weight.size(1)
             std = math.sqrt(2. / float(dimsum))
             nn.init.normal_(weight, 0., std)
+        if self.bias:
             for bias in self.biases:
                 std = math.sqrt(1. / float(bias.size(0)))
                 nn.init.normal_(bias, 0., std)

     def forward(self, input):
-        return mlp_function(input, *self.weights, *self.biases)
+        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)

     def extra_repr(self):
-        s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, ReLU={self.relu}"
+        s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}"
         return s
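At the Python level the options are encoded as integers for the extension: bias becomes 0/1 and activation 'none'/'relu'/'sigmoid' becomes 0/1/2. (Note that "activation is 'relu'" compares object identity and only works because CPython interns short string literals; == is the robust comparison here.) A short usage sketch; the layer sizes below are illustrative only:

import torch
from apex.mlp import MLP

mlp = MLP([512, 256, 128], bias=False, activation='sigmoid').cuda()
x = torch.randn(32, 512, device='cuda', requires_grad=True)
y = mlp(x)            # passes bias=0, activation=2 through to mlp_cuda.forward
y.mean().backward()   # gradients come back through mlp_cuda.backward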
@@ -19,7 +19,9 @@ int mlp_fp(
     int* output_features,
     T** BPtr,
     T* Y,
-    T* reserved_space);
+    T* reserved_space,
+    int use_bias,
+    int activation);

 template <typename T>
 int mlp_bp(
@@ -35,11 +37,18 @@ int mlp_bp(
     T* work_space,
     T* dX,
     T** dwPtr,
-    T** dbPtr);
+    T** dbPtr,
+    bool requires_grad,
+    int use_bias,
+    int activation);

-std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
-  // inputs contains (input, weights, biases)
-  auto num_layers = (inputs.size() - 1) / 2;
+std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
+  auto num_layers = inputs.size() - 1;
+  if (use_bias) {
+    // inputs contains (input, weights, biases)
+    num_layers /= 2;
+  }
   auto batch_size = inputs[0].size(0);
   auto input_features = inputs[0].size(1);
@@ -60,8 +69,10 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
     std::vector<scalar_t*> b_ptr;
     for (int i = 0; i < num_layers; i++) {
       w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
+      if (use_bias) {
         b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
+      }
     }

     auto result = mlp_fp<scalar_t>(
         inputs[0].data_ptr<scalar_t>(),
         input_features,
@@ -71,37 +82,48 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
         output_features.data(),
         b_ptr.data(),
         out.data_ptr<scalar_t>(),
-        reserved_space.data_ptr<scalar_t>());
+        reserved_space.data_ptr<scalar_t>(),
+        use_bias,
+        activation);
   });

   return {out, reserved_space};
 }

 std::vector<at::Tensor> mlp_backward(
+    int use_bias,
+    int activation,
     at::Tensor grad_o,
     std::vector<at::Tensor> fprop_outputs,
     std::vector<at::Tensor> inputs) {
-  // same code to get sizes and W pointers
-  auto num_layers = (inputs.size() - 1) / 2;
+  auto num_layers = inputs.size() - 1;
+  if (use_bias) {
+    // inputs contains (input, weights, biases)
+    num_layers /= 2;
+  }
   auto batch_size = inputs[0].size(0);
   auto input_features = inputs[0].size(1);

+  // TODO: not creating empty tensor for it?
+  bool requires_grad = inputs[0].requires_grad();
+
   std::vector<int> output_features;
   for (int i = 0; i < num_layers; i++) {
     output_features.push_back(inputs[i + 1].size(0));
   }
   // create outputs, length of inputs
+  // TODO: not create bias if not needed
   std::vector<at::Tensor> outputs;
   for (int i = 0; i < inputs.size(); i++) {
     outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));  // clone for testing now
   }

-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
     std::vector<scalar_t*> w_ptr;
-    std::vector<scalar_t*> b_ptr;
     for (int i = 0; i < num_layers; i++) {
       w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
-      b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
     }
     std::vector<scalar_t*> outputs_ptr;
     for (int i = 0; i < inputs.size(); i++) {
@@ -127,7 +149,10 @@ std::vector<at::Tensor> mlp_backward(
         work_space.data_ptr<scalar_t>(),
         outputs_ptr[0],
         outputs_ptr.data() + 1,
-        outputs_ptr.data() + 1 + num_layers);
+        outputs_ptr.data() + 1 + num_layers,
+        requires_grad,
+        use_bias,
+        activation);
   });

   return outputs;
[A larger diff in this commit is collapsed and not shown here.]
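On the C++ side, both entry points now take use_bias and activation up front and recover the layer count from the flat argument list built by the Python wrapper: (input, weight_0 .. weight_{L-1}), with the biases appended only when use_bias is set. mlp_bp additionally receives requires_grad so the backward pass can skip the first layer's input-gradient (dX) work when the network input does not require gradients, the case exercised by test_no_grad below. A small Python mirror of the packing arithmetic, purely for illustration:

def unpack(inputs, use_bias):
    # inputs = [x, w_0..w_{L-1}] or [x, w_0..w_{L-1}, b_0..b_{L-1}]
    num_layers = len(inputs) - 1
    if use_bias:
        num_layers //= 2
    x = inputs[0]
    weights = inputs[1:1 + num_layers]
    biases = inputs[1 + num_layers:] if use_bias else []
    return x, weights, biases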
@@ -51,6 +51,116 @@ class TestMLP(unittest.TestCase):
             ref_mlp[0].bias.grad.detach().cpu().numpy(),
             atol=1e-7, rtol=1e-5)

+    def test_no_bias(self):
+        for use_activation in ['none', 'relu', 'sigmoid']:
+            mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda()
+
+            mlp_layers = []
+            for i in range(mlp.num_layers):
+                linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=False)
+                mlp.weights[i].data.copy_(linear.weight)
+                mlp_layers.append(linear)
+                if use_activation == 'relu':
+                    mlp_layers.append(nn.ReLU(inplace=True))
+                if use_activation == 'sigmoid':
+                    mlp_layers.append(nn.Sigmoid())
+
+            ref_mlp = nn.Sequential(*mlp_layers).cuda()
+
+            test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_()
+            ref_input = test_input.clone().detach().requires_grad_()
+            mlp_out = mlp(test_input)
+            ref_out = ref_mlp(ref_input)
+            np.testing.assert_allclose(
+                mlp_out.detach().cpu().numpy(),
+                ref_out.detach().cpu().numpy(),
+                atol=1e-7, rtol=1e-5)
+
+            # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
+            mlp_out.mean().mul(10.).backward()
+            ref_out.mean().mul(10.).backward()
+            np.testing.assert_allclose(
+                test_input.grad.detach().cpu().numpy(),
+                ref_input.grad.detach().cpu().numpy(),
+                atol=0, rtol=100)
+            np.testing.assert_allclose(
+                mlp.weights[0].grad.detach().cpu().numpy(),
+                ref_mlp[0].weight.grad.detach().cpu().numpy(),
+                atol=1e-7, rtol=100)
+
+    def test_with_bias(self):
+        for use_activation in ['none', 'relu', 'sigmoid']:
+            mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda()
+
+            mlp_layers = []
+            for i in range(mlp.num_layers):
+                linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=True)
+                mlp.weights[i].data.copy_(linear.weight)
+                mlp.biases[i].data.copy_(linear.bias)
+                mlp_layers.append(linear)
+                if use_activation == 'relu':
+                    mlp_layers.append(nn.ReLU(inplace=True))
+                if use_activation == 'sigmoid':
+                    mlp_layers.append(nn.Sigmoid())
+
+            ref_mlp = nn.Sequential(*mlp_layers).cuda()
+
+            test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_()
+            ref_input = test_input.clone().detach().requires_grad_()
+            mlp_out = mlp(test_input)
+            ref_out = ref_mlp(ref_input)
+            np.testing.assert_allclose(
+                mlp_out.detach().cpu().numpy(),
+                ref_out.detach().cpu().numpy(),
+                atol=1e-7, rtol=1e-5)
+
+            # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
+            mlp_out.mean().mul(10.).backward()
+            ref_out.mean().mul(10.).backward()
+            np.testing.assert_allclose(
+                test_input.grad.detach().cpu().numpy(),
+                ref_input.grad.detach().cpu().numpy(),
+                atol=0, rtol=1)
+            np.testing.assert_allclose(
+                mlp.weights[0].grad.detach().cpu().numpy(),
+                ref_mlp[0].weight.grad.detach().cpu().numpy(),
+                atol=1e-7, rtol=1)
+            np.testing.assert_allclose(
+                mlp.biases[0].grad.detach().cpu().numpy(),
+                ref_mlp[0].bias.grad.detach().cpu().numpy(),
+                atol=1e-7, rtol=1e-5)
+
+    def test_no_grad(self):
+        mlp = MLP(mlp_sizes).cuda()
+
+        mlp_layers = []
+        for i in range(mlp.num_layers):
+            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
+            mlp.weights[i].data.copy_(linear.weight)
+            mlp.biases[i].data.copy_(linear.bias)
+            mlp_layers.append(linear)
+            mlp_layers.append(nn.ReLU(inplace=True))
+
+        ref_mlp = nn.Sequential(*mlp_layers).cuda()
+
+        test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.)
+        ref_input = test_input.clone().detach()
+        mlp_out = mlp(test_input)
+        ref_out = ref_mlp(ref_input)
+        np.testing.assert_allclose(
+            mlp_out.detach().cpu().numpy(),
+            ref_out.detach().cpu().numpy(),
+            atol=1e-7, rtol=1e-5)
+
+        # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
+        mlp_out.mean().mul(10.).backward()
+        ref_out.mean().mul(10.).backward()
+        np.testing.assert_allclose(
+            mlp.weights[0].grad.detach().cpu().numpy(),
+            ref_mlp[0].weight.grad.detach().cpu().numpy(),
+            atol=1e-7, rtol=1e-5)
+
     def test_performance_half(self):
         mlp = MLP(mlp_sizes).cuda().half()
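The new tests check outputs and gradients against an nn.Sequential reference in single precision. As an additional local sanity check (not part of this commit's test suite), one could run torch.autograd.gradcheck on the module in double precision, assuming the extension's AT_DISPATCH_FLOATING_TYPES_AND_HALF path handles float64; ReLU's kink can make the finite-difference comparison flaky, so a smooth activation is used in this sketch:

import torch
from apex.mlp import MLP

mlp = MLP([32, 16, 8], bias=True, activation='sigmoid').cuda().double()
x = torch.randn(4, 32, device='cuda', dtype=torch.float64, requires_grad=True)
# compares analytic gradients from mlp_cuda.backward with numerical estimates
torch.autograd.gradcheck(mlp, (x,), eps=1e-6, atol=1e-4)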