Unverified Commit e4e55af7 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Wav2Vec2-Conf / LLaMA] Style fix (#26188)

* torch.nn -> nn

* fix llama

* copies
parent 8b5da9fc
...@@ -99,7 +99,7 @@ class OpenLlamaRMSNorm(nn.Module): ...@@ -99,7 +99,7 @@ class OpenLlamaRMSNorm(nn.Module):
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
class OpenLlamaRotaryEmbedding(torch.nn.Module): class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__() super().__init__()
......
...@@ -89,7 +89,7 @@ class LlamaRMSNorm(nn.Module): ...@@ -89,7 +89,7 @@ class LlamaRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype) return self.weight * hidden_states.to(input_dtype)
class LlamaRotaryEmbedding(torch.nn.Module): class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__() super().__init__()
......
...@@ -584,7 +584,7 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module): ...@@ -584,7 +584,7 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module):
if (config.conv_depthwise_kernel_size - 1) % 2 == 1: if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
self.layer_norm = nn.LayerNorm(config.hidden_size) self.layer_norm = nn.LayerNorm(config.hidden_size)
self.pointwise_conv1 = torch.nn.Conv1d( self.pointwise_conv1 = nn.Conv1d(
config.hidden_size, config.hidden_size,
2 * config.hidden_size, 2 * config.hidden_size,
kernel_size=1, kernel_size=1,
...@@ -592,8 +592,8 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module): ...@@ -592,8 +592,8 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module):
padding=0, padding=0,
bias=False, bias=False,
) )
self.glu = torch.nn.GLU(dim=1) self.glu = nn.GLU(dim=1)
self.depthwise_conv = torch.nn.Conv1d( self.depthwise_conv = nn.Conv1d(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
config.conv_depthwise_kernel_size, config.conv_depthwise_kernel_size,
...@@ -602,9 +602,9 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module): ...@@ -602,9 +602,9 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module):
groups=config.hidden_size, groups=config.hidden_size,
bias=False, bias=False,
) )
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size) self.batch_norm = nn.BatchNorm1d(config.hidden_size)
self.activation = ACT2FN[config.hidden_act] self.activation = ACT2FN[config.hidden_act]
self.pointwise_conv2 = torch.nn.Conv1d( self.pointwise_conv2 = nn.Conv1d(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
kernel_size=1, kernel_size=1,
...@@ -612,7 +612,7 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module): ...@@ -612,7 +612,7 @@ class Wav2Vec2ConformerConvolutionModule(nn.Module):
padding=0, padding=0,
bias=False, bias=False,
) )
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout) self.dropout = nn.Dropout(config.conformer_conv_dropout)
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
...@@ -798,7 +798,7 @@ class Wav2Vec2ConformerEncoderLayer(nn.Module): ...@@ -798,7 +798,7 @@ class Wav2Vec2ConformerEncoderLayer(nn.Module):
# Self-Attention # Self-Attention
self.self_attn_layer_norm = nn.LayerNorm(embed_dim) self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
self.self_attn_dropout = torch.nn.Dropout(dropout) self.self_attn_dropout = nn.Dropout(dropout)
self.self_attn = Wav2Vec2ConformerSelfAttention(config) self.self_attn = Wav2Vec2ConformerSelfAttention(config)
# Conformer Convolution # Conformer Convolution
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment