Unverified Commit 76e05301 authored by Anton Vlasjuk's avatar Anton Vlasjuk Committed by GitHub
Fix a shape annotation and typos in `mamba` slow forward (#30691)

* fix typos and one shape comment

* fix `intermediade` typo in jamba
parent e6708709
@@ -962,15 +962,15 @@ class JambaMambaMixer(nn.Module):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediade_size, seq_len]
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))
...
@@ -279,16 +279,16 @@ class MambaMixer(nn.Module):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, seq_len, intermediade_size]
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))
...
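For reference, here is a minimal, self-contained sketch of the slow scan changed above, run on toy dimensions with random stand-in tensors (the real modules derive these from learned projections; the sizes here are arbitrary). It replays the same broadcasting steps and asserts the corrected shape annotations, including the final stack, which yields [batch, intermediate_size, seq_len] rather than [batch, seq_len, intermediate_size] because `dim=-1` appends the per-step [batch, intermediate_size] outputs along a new trailing axis.

```python
import torch

# Toy sizes; the names mirror the shape comments in the diff.
batch, seq_len, intermediate_size, ssm_state_size = 2, 5, 4, 3
dtype = torch.float32

# Random stand-ins for the module's intermediate tensors.
hidden_states = torch.randn(batch, intermediate_size, seq_len)      # [batch, intermediate_size, seq_len]
discrete_time_step = torch.rand(batch, intermediate_size, seq_len)  # [batch, intermediate_size, seq_len]
A = -torch.exp(torch.randn(intermediate_size, ssm_state_size))      # [intermediate_size, ssm_state_size]
B = torch.randn(batch, seq_len, ssm_state_size)                     # [batch, seq_len, ssm_state_size]
C = torch.randn(batch, seq_len, ssm_state_size)                     # [batch, seq_len, ssm_state_size]
ssm_state = torch.zeros(batch, intermediate_size, ssm_state_size)

# 3.b. Discretization, with the same broadcasting as in the diff.
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
assert discrete_A.shape == (batch, intermediate_size, seq_len, ssm_state_size)
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
assert discrete_B.shape == (batch, intermediate_size, seq_len, ssm_state_size)
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c. The sequential recurrence y ← SSM(A, B, C)(x).
scan_outputs = []
for i in range(seq_len):
    ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state_size]
    scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
    scan_outputs.append(scan_output[:, :, 0])                                  # [batch, intermediate_size]
scan_output = torch.stack(scan_outputs, dim=-1)
assert scan_output.shape == (batch, intermediate_size, seq_len)  # the corrected annotation
```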