Commit e47baefc authored by renzhc

fixed shape error

parent c6714fc3
Pipeline #3195 failed with stages in 0 seconds
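For context, the pattern this commit keeps extending is the "DCU OPT: TN->NN" rewrite: the nn.Linear projections in the attention processors are replaced with an explicit torch.matmul against a weight that has been permuted to (in_features, out_features) at load time, so the GEMM runs in plain NN layout instead of the transposed TN layout implied by nn.Linear. A minimal, self-contained sketch of that equivalence (the shapes here are illustrative only, not taken from the model):

import torch
from torch import nn

# Hypothetical layer and input, for illustration only.
linear = nn.Linear(64, 128, bias=True)
x = torch.randn(2, 10, 64)

# Standard nn.Linear forward: x @ weight.T + bias (weight stored as (out, in), a "TN" GEMM).
ref = linear(x)

# DCU OPT pattern used in this commit: permute the weight once so the forward pass
# becomes a plain matmul (an "NN" GEMM), adding the bias only if one exists.
weight_nn = linear.weight.data.permute(1, 0).contiguous()   # (in, out)
if isinstance(linear.bias, torch.Tensor):
    out = torch.matmul(x, weight_nn) + linear.bias.data
else:
    out = torch.matmul(x, weight_nn)

assert torch.allclose(ref, out, atol=1e-5)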
@@ -2034,15 +2034,29 @@ class AllegroAttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states)
+        # query = attn.to_q(hidden_states)
+        # DCU OPT: TN->NN
+        if attn.to_q.bias:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
+        else:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        # key = attn.to_k(encoder_hidden_states)
+        # value = attn.to_v(encoder_hidden_states)
+        # DCU OPT: TN->NN
+        if attn.to_k.bias:
+            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
+        else:
+            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
+        if attn.to_v.bias:
+            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
+        else:
+            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -2068,7 +2082,9 @@ class AllegroAttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        # hidden_states = attn.to_out[0](hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -2103,9 +2119,24 @@ class AuraFlowAttnProcessor2_0:
         batch_size = hidden_states.shape[0]
         # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        # query = attn.to_q(hidden_states)
+        # key = attn.to_k(hidden_states)
+        # value = attn.to_v(hidden_states)
+        # DCU OPT: TN->NN
+        if isinstance(attn.to_q.bias, torch.Tensor):
+            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
+        else:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data)
+        if isinstance(attn.to_k.bias, torch.Tensor):
+            key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
+        else:
+            key = torch.matmul(hidden_states, attn.to_k.weight.data)
+        if isinstance(attn.to_v.bias, torch.Tensor):
+            value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
+        else:
+            value = torch.matmul(hidden_states, attn.to_v.weight.data)
         # `context` projections.
         if encoder_hidden_states is not None:
@@ -2164,7 +2195,9 @@ class AuraFlowAttnProcessor2_0:
         )
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        # hidden_states = attn.to_out[0](hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
         if encoder_hidden_states is not None:
@@ -2740,7 +2773,7 @@ class AttnProcessor2_0:
         # query = attn.to_q(hidden_states)
         # DCU OPT: TN->NN
-        if attn.to_q.bias:
+        if isinstance(attn.to_q.bias, torch.Tensor):
             query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
         else:
             query = torch.matmul(hidden_states, attn.to_q.weight.data)
@@ -2753,11 +2786,11 @@ class AttnProcessor2_0:
         # key = attn.to_k(encoder_hidden_states)
         # value = attn.to_v(encoder_hidden_states)
         # DCU OPT: TN->NN
-        if attn.to_k.bias:
+        if isinstance(attn.to_k.bias, torch.Tensor):
             key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
         else:
             key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
-        if attn.to_v.bias:
+        if isinstance(attn.to_v.bias, torch.Tensor):
             value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
         else:
             value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)
@@ -2765,24 +2798,9 @@ class AttnProcessor2_0:
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
-        # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # DCU OPT: TN->NN
-        if isinstance(attn.to_q.bias, torch.Tensor):
-            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
-        else:
-            query = torch.matmul(hidden_states, attn.to_q.weight.data)
-        if isinstance(attn.to_k.bias, torch.Tensor):
-            key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
-        else:
-            key = torch.matmul(hidden_states, attn.to_k.weight.data)
-        if isinstance(attn.to_v.bias, torch.Tensor):
-            value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
-        else:
-            value = torch.matmul(hidden_states, attn.to_v.weight.data)
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         if attn.norm_q is not None:
             query = attn.norm_q(query)
......
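The "fixed shape error" of the commit message appears to correspond to the last AttnProcessor2_0 hunk above: the previous revision had commented out the view(...).transpose(1, 2) head-split lines and left a second copy of the TN->NN projection block after them, so query, key, and value were recomputed from the flat hidden_states and never reshaped to (batch, heads, seq_len, head_dim). This commit restores the view lines and removes the duplicated block. The other change in this file swaps the truthiness test on the bias for an isinstance check: truth-testing a multi-element bias tensor raises a RuntimeError, while isinstance handles both a real bias and bias=None (note that the Allegro hunk above still uses the truthiness form). A small sketch of the difference, assuming a 128-element bias:

import torch

bias = torch.zeros(128)

# Truth-testing a multi-element tensor fails at runtime:
try:
    if bias:
        pass
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

# The check used in the updated processors works for both a bias tensor and no bias:
for b in (bias, None):
    print(isinstance(b, torch.Tensor))  # True, then False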
@@ -307,6 +307,7 @@ def load_model_dict_into_meta(
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
     # DCU OPT: TN->NN
+    # add sxx TN->NN
     for param_name, param in model.named_parameters():
         if 'weight' in param_name and 'add_embedding.linear_1' in param_name:
             if param.data.dim() == 2:
@@ -318,6 +319,7 @@
                 param.data = param.data.permute(1, 0).contiguous()
             else:
                 raise ValueError("rzc test error")
         if 'weight' in param_name and 'ff.net' in param_name:
             if param.data.dim() == 2:
                 param.data = param.data.permute(1, 0).contiguous()
@@ -332,6 +334,7 @@
             if param.data.dim() == 2:
                 param.data = param.data.permute(1, 0).contiguous()
             else:
+                #continue
                 raise ValueError("lijian test error")
         if 'weight' in param_name and 'time_embedding' in param_name and ('linear_1' in param_name or 'linear_2' in param_name):
             if param.data.dim() == 2:
@@ -347,6 +350,7 @@
             if param.data.dim() == 2:
                 param.data = param.data.permute(1, 0).contiguous()
             else:
+                #continue
                 raise ValueError("lijian test error")
     return offload_index, state_dict_index
......
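The second file is the model loading path: after the state dict is materialized, every targeted 2-D weight (add_embedding.linear_1, ff.net, the time_embedding linear_1/linear_2 layers, and so on) is permuted to (in_features, out_features) once, which is what lets the processors above call torch.matmul on weight.data directly. A standalone sketch of that load-time permute on a hypothetical toy model (not the actual loader):

import torch
from torch import nn

# Hypothetical toy model standing in for the loaded diffusion model.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))

# Permute matching 2-D weights in place so a rewritten forward can use an NN matmul.
for param_name, param in model.named_parameters():
    if 'weight' in param_name and param.data.dim() == 2:
        param.data = param.data.permute(1, 0).contiguous()

# A forward rewritten as in the diff would then compute:
x = torch.randn(2, 64)
out = torch.matmul(x, model[0].weight.data) + model[0].bias.data  # shape (2, 128)

Note that once the weights are permuted, the stock nn.Linear forward would fail with a shape mismatch, so the permute only makes sense together with the matmul-based processors; the raise ValueError branches in the loader guard the cases where a targeted weight turns out not to be 2-D.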