Commit e47baefc authored by renzhc's avatar renzhc
Browse files

fixed shape error

parent c6714fc3
Pipeline #3195 failed with stages
in 0 seconds
......@@ -2034,15 +2034,29 @@ class AllegroAttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
# query = attn.to_q(hidden_states)
# DCU OPT: TN->NN
        if isinstance(attn.to_q.bias, torch.Tensor):
            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
        else:
            query = torch.matmul(hidden_states, attn.to_q.weight.data)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# key = attn.to_k(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states)
# DCU OPT: TN->NN
        if isinstance(attn.to_k.bias, torch.Tensor):
            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
        else:
            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
        if isinstance(attn.to_v.bias, torch.Tensor):
            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
        else:
            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
......@@ -2068,7 +2082,9 @@ class AllegroAttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# hidden_states = attn.to_out[0](hidden_states)
# DCU OPT: TN->NN
hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -2103,9 +2119,24 @@ class AuraFlowAttnProcessor2_0:
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# query = attn.to_q(hidden_states)
# key = attn.to_k(hidden_states)
# value = attn.to_v(hidden_states)
# DCU OPT: TN->NN
if isinstance(attn.to_q.bias, torch.Tensor):
query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
else:
query = torch.matmul(hidden_states, attn.to_q.weight.data)
if isinstance(attn.to_k.bias, torch.Tensor):
key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
else:
key = torch.matmul(hidden_states, attn.to_k.weight.data)
if isinstance(attn.to_v.bias, torch.Tensor):
value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
else:
value = torch.matmul(hidden_states, attn.to_v.weight.data)
# `context` projections.
if encoder_hidden_states is not None:
......@@ -2164,7 +2195,9 @@ class AuraFlowAttnProcessor2_0:
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# hidden_states = attn.to_out[0](hidden_states)
# DCU OPT: TN->NN
hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
......@@ -2740,7 +2773,7 @@ class AttnProcessor2_0:
# query = attn.to_q(hidden_states)
# DCU OPT: TN->NN
if attn.to_q.bias:
if isinstance(attn.to_q.bias, torch.Tensor):
query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
else:
query = torch.matmul(hidden_states, attn.to_q.weight.data)
......@@ -2753,11 +2786,11 @@ class AttnProcessor2_0:
# key = attn.to_k(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states)
# DCU OPT: TN->NN
if attn.to_k.bias:
if isinstance(attn.to_k.bias, torch.Tensor):
key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
else:
key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
if attn.to_v.bias:
if isinstance(attn.to_v.bias, torch.Tensor):
value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
else:
value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)
......@@ -2765,24 +2798,9 @@ class AttnProcessor2_0:
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# DCU OPT: TN->NN
if isinstance(attn.to_q.bias, torch.Tensor):
query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
else:
query = torch.matmul(hidden_states, attn.to_q.weight.data)
if isinstance(attn.to_k.bias, torch.Tensor):
key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
else:
key = torch.matmul(hidden_states, attn.to_k.weight.data)
if isinstance(attn.to_v.bias, torch.Tensor):
value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
else:
value = torch.matmul(hidden_states, attn.to_v.weight.data)
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
......
......@@ -307,47 +307,51 @@ def load_model_dict_into_meta(
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
# DCU OPT: TN->NN
for param_name, param in model.named_parameters():
if 'weight' in param_name and 'add_embedding.linear_1' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("rzc test error")
if 'weight' in param_name and 'add_embedding.linear_2' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("rzc test error")
if 'weight' in param_name and 'ff.net' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("lijian test error")
if 'weight' in param_name and 'time_emb_proj' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("lijian test error")
if 'weight' in param_name and 'attn' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("lijian test error")
if 'weight' in param_name and 'time_embedding' in param_name and ('linear_1' in param_name or 'linear_2' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("transpose weight to NN error")
if 'weight' in param_name and 'attentions' in param_name and ('proj_in' in param_name or 'proj_out' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("transpose weight to NN error")
if 'weight' in param_name and 'decoder.mid_block.attentions.0' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("lijian test error")
# add sxx TN->NN
for param_name, param in model.named_parameters():
if 'weight' in param_name and 'add_embedding.linear_1' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("rzc test error")
if 'weight' in param_name and 'add_embedding.linear_2' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("rzc test error")
if 'weight' in param_name and 'ff.net' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("lijian test error")
if 'weight' in param_name and 'time_emb_proj' in param_name:
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("lijian test error")
if 'weight' in param_name and 'attn' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
#continue
raise ValueError("lijian test error")
if 'weight' in param_name and 'time_embedding' in param_name and ('linear_1' in param_name or 'linear_2' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("transpose weight to NN error")
if 'weight' in param_name and 'attentions' in param_name and ('proj_in' in param_name or 'proj_out' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
raise ValueError("transpose weight to NN error")
if 'weight' in param_name and 'decoder.mid_block.attentions.0' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
if param.data.dim() == 2:
param.data = param.data.permute(1, 0).contiguous()
else:
#continue
raise ValueError("lijian test error")
return offload_index, state_dict_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment