Unverified Commit 6a895197 authored by Jiayi Yan's avatar Jiayi Yan Committed by GitHub
Browse files

[Bugfix][CI] fix typos (#34934)


Signed-off-by: default avatar1195343015 <1195343015@qq.com>
Signed-off-by: default avatarJiayi Yan <66017932+1195343015@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 8c760b6a
...@@ -1502,10 +1502,10 @@ class RowParallelLinear(LinearBase): ...@@ -1502,10 +1502,10 @@ class RowParallelLinear(LinearBase):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
splitted_input = split_tensor_along_last_dim( split_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size input_, num_partitions=self.tp_size
) )
input_parallel = splitted_input[self.tp_rank].contiguous() input_parallel = split_input[self.tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
......
...@@ -35,7 +35,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -35,7 +35,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add """Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj). custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently oot platforms can still use CustomOp.register_oot to Note that currently oot platforms can still use CustomOp.register_oot to
replace MLA layer entirly, although we use PluggableLayer to register replace MLA layer entirely, although we use PluggableLayer to register
this layer now. this layer now.
This class takes positions and hidden_states as input. This class takes positions and hidden_states as input.
......
...@@ -191,7 +191,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -191,7 +191,7 @@ class CompressedTensorsConfig(QuantizationConfig):
""" """
Helper function to update target_scheme_map Helper function to update target_scheme_map
since linear layers get fused into FusedMoE since linear layers get fused into FusedMoE
targetting 'Linear' needs to also match targeting 'Linear' needs to also match
FusedMoE modules. FusedMoE modules.
""" """
if ( if (
......
...@@ -2445,7 +2445,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2445,7 +2445,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2_scale=layer.w2_weight_scale, # group scale w2_scale=layer.w2_weight_scale, # group scale
g1_alphas=layer.w13_weight_chan_scale, g1_alphas=layer.w13_weight_chan_scale,
g2_alphas=layer.w2_weight_chan_scale, g2_alphas=layer.w2_weight_chan_scale,
per_act_token_quant=True, # always use dynamc per-token per_act_token_quant=True, # always use dynamic per-token
per_out_ch_quant=True, # always use per-channel per_out_ch_quant=True, # always use per-channel
) )
......
...@@ -261,7 +261,7 @@ class CPUAWQLinearMethod(LinearMethodBase): ...@@ -261,7 +261,7 @@ class CPUAWQLinearMethod(LinearMethodBase):
zeros = pack_cols(zeros, bits, group_num, output_size).contiguous() zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
# make 16 output channel as a block and transpose to # make 16 output channel as a block and transpose to
# the make the block contigous # the make the block contiguous
weight = pack_cols(weight, bits, input_size, output_size) weight = pack_cols(weight, bits, input_size, output_size)
weight = ( weight = (
weight.view(input_size, -1, 16 // pack_factor) weight.view(input_size, -1, 16 // pack_factor)
......
...@@ -199,7 +199,7 @@ class TorchAOConfig(QuantizationConfig): ...@@ -199,7 +199,7 @@ class TorchAOConfig(QuantizationConfig):
@classmethod @classmethod
def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig": def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
"""Iniitalize class from a config_dict json string, got from """Initialize class from a config_dict json string, got from
torchao_config_object = some AOBaseConfig object torchao_config_object = some AOBaseConfig object
json.dumps(config_to_dict(torchao_config_object)) json.dumps(config_to_dict(torchao_config_object))
""" """
......
...@@ -255,7 +255,7 @@ def _flashinfer_fp8_blockscale_gemm_impl( ...@@ -255,7 +255,7 @@ def _flashinfer_fp8_blockscale_gemm_impl(
This batch-size-dependent selection is essential for maintaining model accuracy. This batch-size-dependent selection is essential for maintaining model accuracy.
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1 Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accurracy when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
drop. drop.
Args: Args:
......
...@@ -39,7 +39,7 @@ def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]: ...@@ -39,7 +39,7 @@ def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]:
def check_machete_supports_shape( def check_machete_supports_shape(
in_features: int, out_featrues: int in_features: int, out_features: int
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
return ( return (
...@@ -47,7 +47,7 @@ def check_machete_supports_shape( ...@@ -47,7 +47,7 @@ def check_machete_supports_shape(
"Input features size must be divisible by " "Input features size must be divisible by "
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}", f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}",
) )
if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
return ( return (
False, False,
"Output features size must be divisible by " "Output features size must be divisible by "
......
...@@ -237,7 +237,7 @@ class ApplyRotaryEmb(CustomOp): ...@@ -237,7 +237,7 @@ class ApplyRotaryEmb(CustomOp):
Arguments of apply_rotary_emb() in vllm_flash_attn: Arguments of apply_rotary_emb() in vllm_flash_attn:
x: [batch_size, seq_len, nheads, headdim] x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2] cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defalut as False (Neox-style). interleaved: default as False (Neox-style).
... ...
""" """
interleaved = not self.is_neox_style interleaved = not self.is_neox_style
...@@ -259,7 +259,7 @@ class ApplyRotaryEmb(CustomOp): ...@@ -259,7 +259,7 @@ class ApplyRotaryEmb(CustomOp):
Arguments of apply_rotary() in flash_attn: Arguments of apply_rotary() in flash_attn:
x: [batch_size, seq_len, nheads, headdim] x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2] cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defalut as False (Neox-style). interleaved: default as False (Neox-style).
... ...
""" """
interleaved = not self.is_neox_style interleaved = not self.is_neox_style
......
...@@ -342,7 +342,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -342,7 +342,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
text_token_mask = ~visual_token_mask text_token_mask = ~visual_token_mask
final_experts_hidden_states = torch.zeros_like(hidden_states) final_experts_hidden_states = torch.zeros_like(hidden_states)
final_shared_ouput = ( final_shared_output = (
torch.zeros_like(hidden_states) if self.has_shared_experts else None torch.zeros_like(hidden_states) if self.has_shared_experts else None
) )
...@@ -356,26 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -356,26 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module):
text_router_logits, _ = self.text_experts_gate( text_router_logits, _ = self.text_experts_gate(
text_hidden_states.to(dtype=torch.float32) text_hidden_states.to(dtype=torch.float32)
) )
text_shared_ouput, text_experts_output = self.text_experts( text_shared_output, text_experts_output = self.text_experts(
hidden_states=text_hidden_states, router_logits=text_router_logits hidden_states=text_hidden_states, router_logits=text_router_logits
) )
final_experts_hidden_states[text_token_mask] = text_experts_output.flatten() final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
if self.has_shared_experts: if self.has_shared_experts:
final_shared_ouput[text_token_mask] = text_shared_ouput.flatten() final_shared_output[text_token_mask] = text_shared_output.flatten()
vision_router_logits, _ = self.vision_experts_gate( vision_router_logits, _ = self.vision_experts_gate(
vision_hidden_states.to(dtype=torch.float32) vision_hidden_states.to(dtype=torch.float32)
) )
vision_shared_ouput, vision_experts_output = self.vision_experts( vision_shared_output, vision_experts_output = self.vision_experts(
hidden_states=vision_hidden_states, router_logits=vision_router_logits hidden_states=vision_hidden_states, router_logits=vision_router_logits
) )
final_experts_hidden_states[visual_token_mask] = ( final_experts_hidden_states[visual_token_mask] = (
vision_experts_output.flatten() vision_experts_output.flatten()
) )
if self.has_shared_experts: if self.has_shared_experts:
final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten() final_shared_output[visual_token_mask] = vision_shared_output.flatten()
final_hidden_states = (final_shared_ouput, final_experts_hidden_states) final_hidden_states = (final_shared_output, final_experts_hidden_states)
else: else:
# only text modal input # only text modal input
text_router_logits, _ = self.text_experts_gate( text_router_logits, _ = self.text_experts_gate(
......
...@@ -107,7 +107,7 @@ class Conv2dSubsampling(nn.Module): ...@@ -107,7 +107,7 @@ class Conv2dSubsampling(nn.Module):
) )
self.subsampling = 4 self.subsampling = 4
left_context = right_context = 3 # both exclude currect frame left_context = right_context = 3 # both exclude current frame
self.context = left_context + 1 + right_context # 7 self.context = left_context + 1 + right_context # 7
def forward( def forward(
......
...@@ -115,7 +115,7 @@ class EncoderLayerSANM(nn.Module): ...@@ -115,7 +115,7 @@ class EncoderLayerSANM(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mask: torch.Tensor | None = None, mask: torch.Tensor | None = None,
cache=None, cache=None,
mask_shfit_chunk=None, mask_shift_chunk=None,
mask_att_chunk_encoder=None, mask_att_chunk_encoder=None,
): ):
residual = hidden_states residual = hidden_states
...@@ -125,14 +125,14 @@ class EncoderLayerSANM(nn.Module): ...@@ -125,14 +125,14 @@ class EncoderLayerSANM(nn.Module):
hidden_states = residual + self.self_attn( hidden_states = residual + self.self_attn(
hidden_states, hidden_states,
mask, mask,
mask_shfit_chunk=mask_shfit_chunk, mask_shift_chunk=mask_shift_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder, mask_att_chunk_encoder=mask_att_chunk_encoder,
) )
else: else:
hidden_states = self.self_attn( hidden_states = self.self_attn(
hidden_states, hidden_states,
mask, mask,
mask_shfit_chunk=mask_shfit_chunk, mask_shift_chunk=mask_shift_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder, mask_att_chunk_encoder=mask_att_chunk_encoder,
) )
...@@ -140,7 +140,7 @@ class EncoderLayerSANM(nn.Module): ...@@ -140,7 +140,7 @@ class EncoderLayerSANM(nn.Module):
hidden_states = self.norm2(hidden_states) hidden_states = self.norm2(hidden_states)
hidden_states = residual + self.feed_forward(hidden_states) hidden_states = residual + self.feed_forward(hidden_states)
return hidden_states, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder return hidden_states, mask, cache, mask_shift_chunk, mask_att_chunk_encoder
class MultiHeadedAttentionSANM(nn.Module): class MultiHeadedAttentionSANM(nn.Module):
...@@ -183,13 +183,13 @@ class MultiHeadedAttentionSANM(nn.Module): ...@@ -183,13 +183,13 @@ class MultiHeadedAttentionSANM(nn.Module):
self, self,
inputs: torch.Tensor, inputs: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
mask_shfit_chunk: torch.Tensor = None, mask_shift_chunk: torch.Tensor = None,
): ):
b, t, d = inputs.size() b, t, d = inputs.size()
if mask is not None: if mask is not None:
mask = torch.reshape(mask, (b, -1, 1)) mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None: if mask_shift_chunk is not None:
mask = mask * mask_shfit_chunk mask = mask * mask_shift_chunk
inputs = inputs * mask inputs = inputs * mask
x = inputs.transpose(1, 2) x = inputs.transpose(1, 2)
...@@ -243,11 +243,11 @@ class MultiHeadedAttentionSANM(nn.Module): ...@@ -243,11 +243,11 @@ class MultiHeadedAttentionSANM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
mask_shfit_chunk: torch.Tensor = None, mask_shift_chunk: torch.Tensor = None,
mask_att_chunk_encoder: torch.Tensor = None, mask_att_chunk_encoder: torch.Tensor = None,
): ):
q_h, k_h, v_h, v = self.forward_qkv(hidden_states) q_h, k_h, v_h, v = self.forward_qkv(hidden_states)
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) fsmn_memory = self.forward_fsmn(v, mask, mask_shift_chunk)
q_h = q_h * self.d_k ** (-0.5) q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1)) scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
......
...@@ -646,7 +646,7 @@ class IsaacImageProcessor: ...@@ -646,7 +646,7 @@ class IsaacImageProcessor:
return_tensors: str | TensorType | None, return_tensors: str | TensorType | None,
**kwargs: Unpack[IsaacImageProcessorKwargs], **kwargs: Unpack[IsaacImageProcessorKwargs],
) -> BatchFeature: ) -> BatchFeature:
"""Preprocess images into format compatibile with vLLM input processing.""" """Preprocess images into format compatible with vLLM input processing."""
all_pixel_values: list[torch.Tensor] = [] all_pixel_values: list[torch.Tensor] = []
all_image_grids: list[torch.Tensor] = [] all_image_grids: list[torch.Tensor] = []
......
...@@ -299,7 +299,7 @@ class KeyeVisionEmbeddings(nn.Module): ...@@ -299,7 +299,7 @@ class KeyeVisionEmbeddings(nn.Module):
) )
( (
batch_size, batch_size,
squence_len, sequence_len,
channel, channel,
height, height,
width, width,
......
...@@ -238,7 +238,7 @@ class LongcatRouter(nn.Module): ...@@ -238,7 +238,7 @@ class LongcatRouter(nn.Module):
self, self,
config: FlashConfig, config: FlashConfig,
zero_expert_num: int, zero_expert_num: int,
rounter_params_dtype: torch.dtype, router_params_dtype: torch.dtype,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -252,12 +252,12 @@ class LongcatRouter(nn.Module): ...@@ -252,12 +252,12 @@ class LongcatRouter(nn.Module):
config.hidden_size, config.hidden_size,
self.n_routed_experts, self.n_routed_experts,
bias=config.router_bias, bias=config.router_bias,
params_dtype=rounter_params_dtype, params_dtype=router_params_dtype,
quant_config=None, quant_config=None,
prefix=f"{prefix}.classifier", prefix=f"{prefix}.classifier",
) )
self.e_score_correction_bias = nn.Parameter( self.e_score_correction_bias = nn.Parameter(
torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype) torch.zeros((self.n_routed_experts), dtype=router_params_dtype)
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -281,14 +281,14 @@ class LongcatMoe(nn.Module): ...@@ -281,14 +281,14 @@ class LongcatMoe(nn.Module):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
# Gate always runs at half / full precision for now. # Gate always runs at half / full precision for now.
self.rounter_params_dtype = params_dtype self.router_params_dtype = params_dtype
if config.router_dtype == "float32": if config.router_dtype == "float32":
self.rounter_params_dtype = torch.float32 self.router_params_dtype = torch.float32
self.router = LongcatRouter( self.router = LongcatRouter(
config=config, config=config,
zero_expert_num=config.zero_expert_num, zero_expert_num=config.zero_expert_num,
rounter_params_dtype=self.rounter_params_dtype, router_params_dtype=self.router_params_dtype,
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
...@@ -309,7 +309,7 @@ class LongcatMoe(nn.Module): ...@@ -309,7 +309,7 @@ class LongcatMoe(nn.Module):
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
routed_scaling_factor=config.routed_scaling_factor, routed_scaling_factor=config.routed_scaling_factor,
router_logits_dtype=self.rounter_params_dtype, router_logits_dtype=self.router_params_dtype,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -329,7 +329,7 @@ class LongcatMoe(nn.Module): ...@@ -329,7 +329,7 @@ class LongcatMoe(nn.Module):
hidden_states_padded = hidden_states hidden_states_padded = hidden_states
router_logits_full = self.router( router_logits_full = self.router(
hidden_states_padded.to(self.rounter_params_dtype) hidden_states_padded.to(self.router_params_dtype)
) )
# ZeroExpertFusedMoE handles routing memoization and zero expert computation # ZeroExpertFusedMoE handles routing memoization and zero expert computation
......
...@@ -1321,14 +1321,14 @@ def get_image_size(image: ImageInput) -> ImageSize: ...@@ -1321,14 +1321,14 @@ def get_image_size(image: ImageInput) -> ImageSize:
raise ValueError(f"Unknown image type: {type(image)}") raise ValueError(f"Unknown image type: {type(image)}")
def exif_tranpose( def exif_transpose(
images: ImageInput | None, images: ImageInput | None,
) -> ImageInput | None: ) -> ImageInput | None:
if images is None: if images is None:
return None return None
if images is not None and isinstance(images, (list, tuple)): if images is not None and isinstance(images, (list, tuple)):
images = [ images = [
exif_tranpose(img) if isinstance(img, Image) else img for img in images exif_transpose(img) if isinstance(img, Image) else img for img in images
] ]
elif images is not None and isinstance(images, Image): elif images is not None and isinstance(images, Image):
images = ImageOps.exif_transpose(images) images = ImageOps.exif_transpose(images)
...@@ -1667,7 +1667,7 @@ class Molmo2ProcessorWrapper: ...@@ -1667,7 +1667,7 @@ class Molmo2ProcessorWrapper:
**kwargs: object, **kwargs: object,
) -> BatchFeature: ) -> BatchFeature:
inputs = [text] inputs = [text]
images = exif_tranpose(images) images = exif_transpose(images)
if getattr(self.processor, "image_processor", None) is not None: if getattr(self.processor, "image_processor", None) is not None:
inputs.append(images) inputs.append(images)
if getattr(self.processor, "video_processor", None) is not None: if getattr(self.processor, "video_processor", None) is not None:
...@@ -2352,7 +2352,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): ...@@ -2352,7 +2352,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
def get_image_replacement_molmo2(item_idx: int) -> list[int]: def get_image_replacement_molmo2(item_idx: int) -> list[int]:
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image = images.get(item_idx) image = images.get(item_idx)
image = exif_tranpose(image) image = exif_transpose(image)
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False) resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
if use_single_crop_col_tokens is not None: if use_single_crop_col_tokens is not None:
......
...@@ -349,7 +349,7 @@ class NemotronHMoEDecoderLayer(nn.Module): ...@@ -349,7 +349,7 @@ class NemotronHMoEDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
# Get per-layer config for heterogeneous models if exsist # Get per-layer config for heterogeneous models if exists
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None) get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
layer_config = get_layer_config(layer_idx) if get_layer_config else config layer_config = get_layer_config(layer_idx) if get_layer_config else config
...@@ -517,7 +517,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): ...@@ -517,7 +517,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
# Get per-layer config for heterogeneous models if exsist # Get per-layer config for heterogeneous models if exists
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None) get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
layer_config = get_layer_config(layer_idx) if get_layer_config else config layer_config = get_layer_config(layer_idx) if get_layer_config else config
......
...@@ -486,7 +486,7 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -486,7 +486,7 @@ class SiglipVisionEmbeddings(nn.Module):
) )
( (
batch_size, batch_size,
squence_len, sequence_len,
channel, channel,
height, height,
width, width,
......
...@@ -689,19 +689,19 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -689,19 +689,19 @@ class ConformerEncoder(TransformerEncoderBase):
default False. default False.
ext_pw_out_channel: int, optional ext_pw_out_channel: int, optional
the number of channel for CNN the number of channel for CNN
before depthwise_seperable_CNN. before depthwise_separable_CNN.
If 0 then use linear. default 0. If 0 then use linear. default 0.
ext_pw_kernel_size: int, optional ext_pw_kernel_size: int, optional
kernel size of N before depthwise_seperable_CNN. kernel size of N before depthwise_separable_CNN.
only work for ext_pw_out_channel > 0. only work for ext_pw_out_channel > 0.
default 1 default 1
depthwise_seperable_out_channel: int, optional depthwise_seperable_out_channel: int, optional
the number of channel for the number of channel for
depthwise_seperable_CNN. depthwise_separable_CNN.
default 256. default 256.
depthwise_multiplier: int, optional depthwise_multiplier: int, optional
the number of multiplier for the number of multiplier for
depthwise_seperable_CNN. depthwise_separable_CNN.
default 1. default 1.
chunk_se: int, optional chunk_se: int, optional
0 for offline SE. 0 for offline SE.
...@@ -711,7 +711,7 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -711,7 +711,7 @@ class ConformerEncoder(TransformerEncoderBase):
by only the current chunk. by only the current chunk.
default 0. default 0.
kernel_size: int, optional kernel_size: int, optional
the number of kernels for depthwise_seperable_CNN. the number of kernels for depthwise_separable_CNN.
default 3. default 3.
activation: str, optional activation: str, optional
FeedForward block activation. FeedForward block activation.
...@@ -721,7 +721,7 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -721,7 +721,7 @@ class ConformerEncoder(TransformerEncoderBase):
activation function used in ConvModule part activation function used in ConvModule part
of the conformer, default "relu". of the conformer, default "relu".
conv_glu_type: str, optional conv_glu_type: str, optional
activation used use glu in depthwise_seperable_CNN, activation used use glu in depthwise_separable_CNN,
default "sigmoid" default "sigmoid"
bias_in_glu: bool, optional bias_in_glu: bool, optional
if set to True, use additive bias in the weight module if set to True, use additive bias in the weight module
......
...@@ -217,8 +217,8 @@ class GLUPointWiseConv(nn.Module): ...@@ -217,8 +217,8 @@ class GLUPointWiseConv(nn.Module):
return x return x
class DepthWiseSeperableConv1d(nn.Module): class DepthWiseSeparableConv1d(nn.Module):
"""DepthWiseSeperableConv1d module used in Convnet module """DepthWiseSeparableConv1d module used in ConvNet module
for the conformer, for more details see: for the conformer, for more details see:
https://arxiv.org/pdf/2005.08100v1.pdf https://arxiv.org/pdf/2005.08100v1.pdf
...@@ -390,7 +390,7 @@ class ConvModule(nn.Module): ...@@ -390,7 +390,7 @@ class ConvModule(nn.Module):
else: else:
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
self.dw_sep_conv_1d = DepthWiseSeperableConv1d( self.dw_sep_conv_1d = DepthWiseSeparableConv1d(
input_dim, input_dim,
depthwise_seperable_out_channel, depthwise_seperable_out_channel,
kernel_size, kernel_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment