#ifndef WENETPARAMS_H
#define WENETPARAMS_H

/*
 * Parameter containers for the encoder / CTC / decoder parts of the model.
 *
 * Every struct in this header consists solely of fixed-size `float` arrays
 * and nested structs of the same kind. No pointers or other scalar types are
 * used, so each struct is one contiguous run of floats whose order is exactly
 * the member declaration order.
 *
 * NOTE(review): the layout looks like it is meant to mirror a flat weight
 * file that is read or mapped directly into a WenetParams object -- confirm
 * against the loader before reordering or resizing any member.
 *
 * NOTE(review): 512 appears throughout and is presumably the model width,
 * with 2048 the feed-forward width -- confirm against the model config.
 *
 * WenetParams is several hundred megabytes in size; do not place it on the
 * stack.
 */

/* Packing was disabled here; with float-only members no padding is expected.
 */
// #pragma pack(1)

/*
 * Number of entries used to size the embedding, output and CTC tables below.
 * NOTE(review): this is a lowercase object-like macro, so any identifier
 * spelled `vocab_size` in a file that includes this header is textually
 * replaced by 5538. The name is kept for compatibility with existing users.
 */
#define vocab_size 5538

/* Weights and biases of the encoder input embedding: two conv layers followed
 * by one output projection. */
typedef struct {
    float conv0_weight[512 * 9];
    float conv0_bias[512];
    float conv1_weight[512 * 512 * 9];
    float conv1_bias[512];
    float out0_weight[9728 * 512];
    float out0_bias[512];
} EncEmbedParams;

/* Query / key / value / output projections of one attention block
 * (512x512 weights, 512 biases each). */
typedef struct {
    float linear_q_weight[512 * 512];
    float linear_q_bias[512];
    float linear_k_weight[512 * 512];
    float linear_k_bias[512];
    float linear_v_weight[512 * 512];
    float linear_v_bias[512];
    float linear_out_weight[512 * 512];
    float linear_out_bias[512];
} SelfAttnParams;

/* Encoder attention: the plain attention projections plus a positional
 * projection and two positional bias vectors. */
typedef struct {
    SelfAttnParams linear0;
    float linear_pos_weight[512 * 512];
    float pos_bias_u[512];
    float pos_bias_v[512];
} EncSelfAttnParams;

/* Two-layer feed-forward block: 512 -> 2048 (w1) and 2048 -> 512 (w2). */
typedef struct {
    float w1_weight[512 * 2048];
    float w1_bias[2048];
    float w2_weight[2048 * 512];
    float w2_bias[512];
} FeedForwardParams;

/* Per-channel scale (weight) and shift (bias) of one normalisation layer. */
typedef struct {
    float weight[512];
    float bias[512];
} NormParams;

/* Encoder convolution module: pointwise conv, depthwise conv (15 taps per
 * channel), second pointwise conv, and a normalisation layer. */
typedef struct {
    float pointwise_conv1_weight[1024 * 512];
    float pointwise_conv1_bias[1024];
    float depthwise_conv_weight[512 * 15];
    float depthwise_conv_bias[512];
    float pointwise_conv2_weight[512 * 512];
    float pointwise_conv2_bias[512];
    NormParams norm;
} EncConvParams;

/* All parameters of a single encoder layer. */
typedef struct {
    EncSelfAttnParams self_attn;
    FeedForwardParams feedforward;
    FeedForwardParams feedforward_macaron;
    EncConvParams conv_module;
    NormParams norm_ff;
    NormParams norm_mha;
    NormParams norm_macaron;
    NormParams norm_conv;
    NormParams norm_final;
    /* Disabled members, kept for reference (not part of the layout): */
    // float concat_weight[1024 * 512];
    // float concat_bias[512];
} SubEncoderParams;

/* Full encoder: input embedding, 12 encoder layers, final normalisation. */
typedef struct {
    EncEmbedParams embed;
    SubEncoderParams sub_encoder[12];
    NormParams after_norm;
} EncoderParams;

/*
 * All parameters of a single decoder layer.
 * NOTE(review): `feedward` is presumably a misspelling of `feedforward`; the
 * member name is kept as-is because code elsewhere may reference it.
 */
typedef struct {
    SelfAttnParams self_attn;
    SelfAttnParams src_attn;
    FeedForwardParams feedward;
    NormParams norm1;
    NormParams norm2;
    NormParams norm3;
    /* Disabled members, kept for reference (not part of the layout): */
    // float concat_weight1[1024 * 512];
    // float concat_bias1[512];
    // float concat_weight2[1024 * 512];
    // float concat_bias2[512];
} SubDecoderParams;

/* Full decoder: embedding table, 6 decoder layers, final normalisation and
 * the output projection with its bias. */
typedef struct {
    float embed_weight[vocab_size * 512];
    SubDecoderParams sub_decoder[6];
    NormParams after_norm;
    float output_weight[vocab_size * 512];
    float output_bias[vocab_size];
} DecoderParams;

/* Top-level parameter block: encoder, CTC projection, then decoder, in that
 * order. */
typedef struct {
    EncoderParams encoder;
    float ctc_weight[512 * vocab_size];
    float ctc_bias[vocab_size];
    DecoderParams decoder;
} WenetParams;

// #pragma pack()

#endif /* WENETPARAMS_H */