Commit a8ecd3d7 authored by Rick Ho

remove debug output and todo for replicated mp input

parent 01ae2d72
@@ -24,7 +24,6 @@ class FMoELinear(nn.Module):
 class FMoENaiveGate(nn.Module):
     def __init__(self, d_model, num_expert, world_size, top_k=2):
         super(FMoENaiveGate, self).__init__()
-        # print(f"gate: {num_expert * world_size}")
         self.gate = nn.Linear(d_model, num_expert * world_size)
         self.top_k = top_k
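For context, a gate of this shape projects every token onto one logit per expert (num_expert * world_size logits in total) and keeps the top-k of them. A minimal sketch of that idea, not the actual FMoENaiveGate forward pass (the helper name and the placement of the softmax are assumptions):

import torch

def naive_top_k_gate(inp, gate_linear, top_k):
    # inp: (tokens, d_model); gate_linear: nn.Linear(d_model, num_expert * world_size)
    logits = gate_linear(inp)                           # (tokens, total_experts)
    top_val, top_idx = torch.topk(logits, k=top_k, dim=-1)
    score = torch.softmax(top_val, dim=-1)              # normalize only over the selected experts
    return top_idx, score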
@@ -92,7 +91,6 @@ class FMoETransformerMLP(nn.Module):
         self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
-        # print(f"FMoETransformerMLP world_size: {world_size} num_expert: {num_expert}")
         self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
         self.layer_norm = nn.LayerNorm(d_model)
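Each expert is the usual two-layer transformer FFN (d_model -> d_hidden -> d_model), instantiated num_expert times inside a single FMoELinear. A rough sketch of what one token routed to a single expert goes through, not the fused kernel itself (the activation choice here is an assumption):

import torch

def expert_ffn(x, w_htoh4, w_h4toh):
    # x: (d_model,); w_htoh4: (d_hidden, d_model); w_h4toh: (d_model, d_hidden)
    h4 = torch.relu(w_htoh4 @ x)    # d_model -> d_hidden
    return w_h4toh @ h4             # d_hidden -> d_model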
@@ -107,8 +105,6 @@ class FMoETransformerMLP(nn.Module):
             batch_start = local_batch_size * self.model_parallel_rank
             batch_end = min(batch_start + local_batch_size, B)
             inp = inp[:, batch_start:batch_end, :].contiguous()
-            # print(inp.shape)
-            # print(f"mp_rank: {self.model_parallel_rank}, [{batch_start}, {batch_end})")
         residual = inp
         if self.pre_lnorm:
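These lines carve the model-parallel rank's share out of an input that every rank receives in full: with input shaped (L, B, d_model), rank r keeps batch columns [r * local_batch_size, min((r + 1) * local_batch_size, B)). A small numeric illustration; the sizes are made up, and the ceil-division formula for local_batch_size is an assumption since it is not shown in this diff:

# Replicated input of shape (L, B, d_model) with B = 10, split across 4 model-parallel ranks.
B, world = 10, 4
local_batch_size = (B + world - 1) // world   # 3, assuming ceil division
for rank in range(world):
    batch_start = local_batch_size * rank
    batch_end = min(batch_start + local_batch_size, B)
    print(rank, batch_start, batch_end)       # 0:[0,3) 1:[3,6) 2:[6,9) 3:[9,10)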
@@ -116,7 +112,6 @@ class FMoETransformerMLP(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
-        # TODO: merge replication into local_scatter
         inp = inp.view(-1, self.d_model).repeat_interleave(
             repeats=self.top_k, dim=0
         )  # (BxLxtop_k) x d_model
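Because each token is routed to top_k experts, the input is flattened to (B*L, d_model) and every row is duplicated top_k times, so copy i can be dispatched to the token's i-th selected expert; the removed TODO noted that this replication could eventually be folded into local_scatter. A quick illustration of the duplication pattern produced by repeat_interleave:

import torch

x = torch.arange(6.0).view(3, 2)              # 3 tokens, d_model = 2
y = x.repeat_interleave(repeats=2, dim=0)     # top_k = 2
print(y)
# rows come out as t0, t0, t1, t1, t2, t2 (one copy per selected expert)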
...
#!/bin/bash
if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
then
    export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
...