Unverified commit 7eb47b0f authored by Chunyuan WU, committed by GitHub

[CPU] [BF16] Call fused_experts_cpu, weight_packed_linear and bmm_cpu kernel in DeepSeek model (#6641)

Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
parent bc2e5645
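
Reader's note (not part of the diff): every call site below follows the same dispatch pattern, summarized in the minimal Python sketch here. It assumes the weight was packed at load time by _process_weight_after_loading (added to sglang.srt.utils in this commit), which also sets the use_intel_amx_backend flag.

import torch
import torch.nn.functional as F

def linear_forward(layer: torch.nn.Module, x: torch.Tensor, bias=None) -> torch.Tensor:
    # Packed BF16 path: the weight is already in VNNI layout, flag set at load time.
    if getattr(layer, "use_intel_amx_backend", False):
        return torch.ops.sgl_kernel.weight_packed_linear(x, layer.weight, bias, True)  # is_vnni
    # Fallback: plain PyTorch path when not on CPU or AMX is unavailable.
    return F.linear(x, layer.weight, bias)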
@@ -30,7 +30,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 logger = logging.getLogger(__name__)
@@ -52,6 +57,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "IPEXAWQLinearMethod",
 ]

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +184,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)
...
@@ -442,11 +442,20 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

         if self.logit_scale is not None:
...
@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
...
@@ -18,7 +18,14 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +35,8 @@ else:
 import logging

 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
@@ -117,6 +126,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
+        # Pack weight to get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return

     def apply(
@@ -248,19 +262,64 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-        )
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
...
@@ -20,10 +20,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    PackWeightMethod,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 DEFAULT_VOCAB_PADDING_SIZE = 64

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support packing the LMHead weight when it is not quantized; for an LMHead with quant_config, the weight name is "qweight".
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
...
@@ -93,6 +93,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -144,6 +145,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()

+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -212,8 +216,18 @@ class MoEGate(nn.Module):
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
@@ -778,6 +792,37 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )

+        # If self.fused_qkv_a_proj_with_mqa exists and we are running on CPU, we choose the
+        # torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel, which requires self.w_kc
+        # and self.w_vc to be packed. Otherwise we use torch.bmm and the weights should not be packed.
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -791,7 +836,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            return AttnForwardMethod.MLA
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                self, "use_intel_amx_backend", False
+            ):
+                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+            else:
+                return AttnForwardMethod.MLA

         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -905,6 +955,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -924,6 +978,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError
@@ -1241,6 +1297,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )

+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1314,6 +1421,43 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output

+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align the shapes of the bmm inputs.
+        # Shapes of the inputs:
+        #   q_nope: [M, B, K]
+        #   original self.w_kc: [B, K, N]
+        #   current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+        # Shapes of the inputs to sgl_kernel.cpu.bmm:
+        #   out: [B, M, N]
+        #   mat1: [B, M, K]
+        #   mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,
...
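
The shape note in forward_absorb_fused_mla_rope_cpu_core can be checked with plain torch.bmm. Below is a shape-only sketch with made-up sizes; torch.bmm stands in for torch.ops.sgl_kernel.bmm_cpu, which consumes mat2 directly in its packed [B, N, K] form, and float32 stands in for the BF16 used on the real path.

import torch

B, M, N, K = 16, 4, 128, 512                   # heads, tokens, v_head_dim, kv_lora_rank
attn_output = torch.randn(M, B, K)             # [M, B, K]
w_vc = torch.randn(B, N, K)                    # packed layout: [B, N, K]

# Preallocate the final [M, B * N] buffer and hand the matmul a [B, M, N] view of it,
# so the result lands directly in the layout o_proj expects.
output = torch.empty(M, B * N)
attn_bmm_output = output.view(M, B, N).transpose_(0, 1)    # [B, M, N] view

# out[B, M, N] = mat1[B, M, K] @ mat2[B, N, K]^T
attn_bmm_output.copy_(torch.bmm(attn_output.transpose(0, 1), w_vc.transpose(1, 2)))
assert output.shape == (M, B * N)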
@@ -2457,6 +2457,77 @@ def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available

+
+def prepack_weight_if_needed(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently the gemm kernel has the following requirements:
+#   OC % TILE_N == 0, where TILE_N = 16
+#   IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
+
+
+def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
+    # Pack weight to get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, "Expected all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        # Don't pack the weight or use the intel amx backend if any weight of this
+        # module has an unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
+                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
+                f"{module} won't use intel amx backend."
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        packed_weight = torch.nn.Parameter(
+            prepack_weight_if_needed(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
+
+
 class LazyValue:
     def __init__(self, creator: Callable):
         self._creator = creator
...
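
A hedged end-to-end sketch of the new helpers above: the module and sizes are invented for illustration, while _process_weight_after_loading and the weight_packed_linear signature are the ones added in this commit. It assumes an AMX-capable CPU and a sgl-kernel build that ships the CPU ops; otherwise the helper leaves the weight unpacked and the F.linear fallback runs.

import torch
import torch.nn.functional as F

from sglang.srt.utils import _process_weight_after_loading


class TinyLinear(torch.nn.Module):
    def __init__(self, out_features=16 * 4, in_features=32 * 8):
        super().__init__()
        # Sizes respect the packing constraints: OC % 16 == 0, IC % 32 == 0.
        self.weight = torch.nn.Parameter(
            torch.randn(out_features, in_features, dtype=torch.bfloat16),
            requires_grad=False,
        )

    def forward(self, x):
        if getattr(self, "use_intel_amx_backend", False):
            return torch.ops.sgl_kernel.weight_packed_linear(
                x, self.weight, None, True  # bias, is_vnni
            )
        return F.linear(x, self.weight)


mod = TinyLinear()
_process_weight_after_loading(mod, ["weight"])  # packs the weight and sets use_intel_amx_backend
y = mod(torch.randn(4, 32 * 8, dtype=torch.bfloat16))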
@@ -318,8 +318,8 @@ void weight_packed_linear_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);

-  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
-  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
+  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx; also use amx when c) N is small
+  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);

   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
...
@@ -28,7 +28,7 @@ class Mod(nn.Module):

 class TestGemm(CustomTestCase):
     M = [1, 101]
-    N = [32 * 13]
+    N = [16, 32 * 13]
     K = [32 * 16]
     has_bias = [False, True]
...