Commit e515f026 authored by hwijeen

fix typo

parent 90e0a0dd
@@ -53,8 +53,7 @@ class ParallelMLP(MegatronModule):
     MLP will take the input with h hidden state, project it to 4*h
     hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension. At the end, dropout is also
-    applied.
+    state back into h hidden dimension.
     """

     def __init__(self, init_method, output_layer_init_method):
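The docstring above describes the standard transformer MLP shape. For reference, here is a minimal single-device sketch of that pattern in plain PyTorch, ignoring the tensor parallelism that `ParallelMLP` adds; the layer names mirror Megatron's, but the module itself is illustrative. Note that, per this commit, dropout is no longer applied inside the MLP.

```python
import torch
import torch.nn.functional as F

class SimpleMLP(torch.nn.Module):
    """Single-device sketch of the pattern in the docstring:
    project h -> 4h, apply a nonlinearity, project 4h -> h.
    (Dropout is handled outside the MLP, matching this change.)"""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, hidden_states):
        # [s, b, h] -> [s, b, 4h], nonlinearity applied elementwise
        intermediate = F.gelu(self.dense_h_to_4h(hidden_states))
        # [s, b, 4h] -> [s, b, h]
        return self.dense_4h_to_h(intermediate)
```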
@@ -84,7 +83,6 @@ class ParallelMLP(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
-

     def forward(self, hidden_states):

         # [s, b, 4hp]
...
@@ -256,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):
                 device=torch.cuda.current_device(), dtype=args.params_dtype))
             _initialize_affine_weight_gpu(self.weight, init_method,
                                           partition_dim=0, stride=stride)
-
+
         if bias:
             if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(
@@ -286,7 +286,7 @@ class ColumnParallelLinear(torch.nn.Module):
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
-
+
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
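For context on the `gather_output` branch above: in a column-parallel linear layer the weight matrix is split along its output dimension, each rank computes a slice of the output, and the all-gather concatenates those slices back into the full result. Below is a single-process sketch of the math only; no real process groups are used, and `world_size` and the shapes are made up for illustration.

```python
import torch

# Conceptual illustration of what gather_from_tensor_model_parallel_region
# achieves for a column-parallel linear layer: the weight A is split along
# its output (column) dimension, each "rank" computes x @ A_i^T, and
# concatenating the partial outputs reproduces the full x @ A^T.
torch.manual_seed(0)
world_size, in_features, out_features = 2, 8, 16
x = torch.randn(4, in_features)
weight = torch.randn(out_features, in_features)

# Each partition holds out_features // world_size rows of the weight.
partitions = torch.chunk(weight, world_size, dim=0)
partial_outputs = [x @ w_i.t() for w_i in partitions]  # per-"rank" slices

# "All-gather" = concatenate the partial outputs along the output dimension.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, x @ weight.t(), atol=1e-5)
```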
@@ -316,8 +316,8 @@ class RowParallelLinear(torch.nn.Module):
         keep_master_weight_for_test: This was added for testing and should be
                                      set to False. It returns the master weights
                                      used for initialization.
-        skip_bias_add: This was added to enable performance optimations where bias
-                       can be fused with other elementwise operations. we skip
+        skip_bias_add: This was added to enable performance optimization where bias
+                       can be fused with other elementwise operations. We skip
                        adding bias but instead return it.
     """
...
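The `skip_bias_add` behavior this docstring describes is why `forward` returns `(output, output_bias)`: the caller can fold the bias add into a later elementwise kernel instead of paying for a separate pass over the tensor. A rough sketch of the idea, with `fused_bias_gelu` as a hypothetical fusion target; this is the reference math only, not a real fused kernel.

```python
import math
import torch

def fused_bias_gelu(x, bias):
    # Reference math for a hypothetical fused kernel: the bias add and the
    # (erf-based) GeLU happen in one elementwise pass instead of bias-add
    # and activation running as two separate kernels.
    y = x + bias
    return 0.5 * y * (1.0 + torch.erf(y / math.sqrt(2.0)))

# Usage sketch (names hypothetical): with skip_bias_add=True the layer
# returns the un-biased output plus its bias, which the fused op consumes.
#   output, bias = parallel_linear(hidden_states)
#   hidden_states = fused_bias_gelu(output, bias)
```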