Commit c9d35a49 authored by rohithkrn

fix bf16 layernorm bug

parent 37989915
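The patch below extends the existing float32 promotion of the layer-norm statistics buffers (mean, invvar, and part_grad_gamma), which previously applied only to Half inputs, so that it also covers BFloat16 inputs. A minimal sketch of the shared rule, using a hypothetical helper name stats_dtype that is not part of the commit, could look like this:

// Minimal sketch (assumption, not part of this commit): a helper capturing the
// promotion rule applied in all three hunks -- statistics tensors are allocated
// in float32 whenever the input is a reduced-precision type (Half or BFloat16),
// otherwise in the input's own dtype.
#include <torch/extension.h>

static at::ScalarType stats_dtype(at::ScalarType input_type) {
  return (input_type == at::ScalarType::Half ||
          input_type == at::ScalarType::BFloat16)
             ? at::ScalarType::Float
             : input_type;
}

// Hypothetical usage mirroring the patched allocations:
//   at::Tensor mean = at::empty({n1},
//       input.options().dtype(stats_dtype(input.scalar_type())));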
@@ -130,7 +130,8 @@ std::vector<at::Tensor> layer_norm(
   int n1,n2;
   check_args(input,normalized_shape,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half ||
+    input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
     normalized_shape,NULL,NULL,epsilon);
@@ -152,7 +153,8 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half ||
+    input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
     normalized_shape,&gamma,&beta,epsilon);
...
@@ -730,7 +730,8 @@ void HostLayerNormGradient(
   const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
   const int nshared2_b = threads2.x * threads2.y * sizeof(U);
   const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-  at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+  at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype((input->scalar_type()==at::ScalarType::Half ||
+    input->scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type()));
   at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
   cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
     dout,
...