Unverified Commit f8a173bb authored by Lifu Huang, committed by GitHub

Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)

parent 6b847a9a
@@ -35,7 +35,7 @@
"\n", "\n",
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n", "* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
"\n", "\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", "* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"\n", "\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n", "\n",
......
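To make the arguments above concrete, here is a hedged sketch of a server launch that combines them. The model path, adapter paths, and port are placeholders, and the `name=path` form of `--lora-paths` plus the standard `--model-path`/`--port` options are assumed rather than taken from this diff.

```python
import subprocess

# Placeholder model and adapter paths; substitute your own.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--lora-paths", "adapter_a=/path/to/adapter_a", "adapter_b=/path/to/adapter_b",
    "--max-loras-per-batch", "4",   # adapters that can be active in a single batch
    "--max-loaded-loras", "8",      # CPU-side cap; must be >= max-loras-per-batch
    "--max-lora-rank", "64",        # allow dynamically loading adapters up to rank 64
    "--lora-backend", "triton",     # the only supported LoRA backend after this change
    "--port", "30000",
]
subprocess.run(cmd, check=True)
```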
@@ -5,22 +5,6 @@ import torch
from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.lora.utils import LoRABatchInfo
def get_fuse_output_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
class BaseLoRABackend: class BaseLoRABackend:
"""Base class for different Lora backends. """Base class for different Lora backends.
Each backend has its own implementation of Lora kernels. Each backend has its own implementation of Lora kernels.
@@ -28,15 +12,11 @@ class BaseLoRABackend:
Args: Args:
name: name of backend name: name of backend
batch_info: information of current batch for use batch_info: information of current batch for use
fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of adding will be fused into kernel
""" """
def __init__(self, name: str, batch_info: LoRABatchInfo = None): def __init__(self, name: str, batch_info: LoRABatchInfo = None):
self.name = name self.name = name
self.batch_info = batch_info self.batch_info = batch_info
self.fuse_output_add = get_fuse_output_add_from_name(name)
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
def run_lora_a_sgemm( def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
return TritonLoRABackend return TritonLoRABackend
elif name == "flashinfer": elif name == "flashinfer":
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend raise ValueError(
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
return FlashInferLoRABackend )
else: else:
raise ValueError(f"Invalid backend: {name}") raise ValueError(f"Invalid backend: {name}")
from typing import Tuple
import torch
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper
class FlashInferLoRABackend(BaseLoRABackend):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
# FIXME wait for flashinfer segment gemm update
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return (
self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
* self.batch_info.scalings[0]
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
lora_output = torch.empty(
(x.shape[0], output_dim_q + 2 * output_dim_kv),
device=x.device,
dtype=x.dtype,
)
# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)
# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)
lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)
return lora_output * self.batch_info.scalings[0]
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
lora_rank = gate_up_lora_b[0].shape[-1]
output_dim = gate_up_lora_b[0].shape[-2]
# Shape of lora_a_output: (s, 2 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
lora_output = torch.empty(
(x.shape[0], 2 * output_dim),
device=x.device,
dtype=x.dtype,
)
# Compute lora for gate and up proj respectively
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(),
weights=gate_up_lora_b[0],
)
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank:].contiguous(),
weights=gate_up_lora_b[1],
)
return lora_output * self.batch_info.scalings[0]
from typing import List, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.B_buffer = B_buffer self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm( lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output, x=lora_a_output,
self.B_buffer[0], weights=self.B_buffer,
**backend_kwargs, base_output=base_output,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
) )
return lora_output
def forward(self, input_: torch.Tensor): def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear # duplicate the logic in ColumnParallelLinear
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
): ):
self.set_lora = True self.set_lora = True
self.A_buffer_gate_up = A_buffer self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b: self.B_buffer_gate_up = B_buffer
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
if getattr(self, "B_buffer_gate_up", None) is None:
self.B_buffer_gate_up = torch.empty(
(
B_buffer[0].shape[0],
2 * B_buffer[0].shape[1],
B_buffer[0].shape[2],
),
dtype=B_buffer[0].dtype,
device=B_buffer[0].device,
)
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_output = self.lora_backend.run_gate_up_lora( lora_output = self.lora_backend.run_gate_up_lora(
x, x=x,
self.A_buffer_gate_up, gate_up_lora_a=self.A_buffer_gate_up,
self.B_buffer_gate_up, gate_up_lora_b=self.B_buffer_gate_up,
**backend_kwargs, base_output=base_output,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
) )
return lora_output
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A return A
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
# Since the outputs for both gate and up are identical, we use a random one. # Since the outputs for both gate and up are identical, we use a random one.
shard_size = self.base_layer.output_partition_sizes[0] shard_size = self.base_layer.output_partition_sizes[0]
gate_size = self.base_layer.output_sizes[0]
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size end_idx = (tp_rank + 1) * shard_size
return B[:, start_idx:end_idx, :] return torch.concat(
(
B[start_idx:end_idx, :],
B[gate_size + start_idx : gate_size + end_idx],
),
dim=0,
)
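With LoRA B for `gate_up_proj` now stored as one stacked tensor of shape `(gate_size + up_size, r)`, tensor-parallel slicing has to pull this rank's rows from both halves and restack them, which is what the `torch.concat` above does. A toy sketch with assumed sizes, not the SGLang layer class:

```python
import torch

def slice_stacked_gate_up_lora_b(B, gate_size, shard_size, tp_rank):
    # B: (gate_size + up_size, r), with the gate rows first and the up rows after.
    start, end = tp_rank * shard_size, (tp_rank + 1) * shard_size
    return torch.cat(
        (
            B[start:end, :],                           # this rank's gate rows
            B[gate_size + start : gate_size + end, :], # this rank's up rows
        ),
        dim=0,
    )

r, gate_size, up_size, tp_size = 8, 16, 16, 4
B = torch.arange((gate_size + up_size) * r, dtype=torch.float32).reshape(-1, r)
shard = slice_stacked_gate_up_lora_b(B, gate_size, gate_size // tp_size, tp_rank=1)
print(shard.shape)  # torch.Size([8, 8]) -> 4 gate rows + 4 up rows for rank 1
```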
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_backend: BaseLoRABackend, lora_backend: BaseLoRABackend,
) -> None: ) -> None:
super().__init__(base_layer, lora_backend) super().__init__(base_layer, lora_backend)
q_proj_shard_size = self.base_layer.q_proj_shard_size
kv_proj_shard_size = self.base_layer.kv_proj_shard_size
self.output_offset = torch.tensor(
[
0,
q_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size,
q_proj_shard_size + 2 * kv_proj_shard_size,
],
dtype=torch.int32,
device=next(self.base_layer.parameters()).device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
def set_lora_info( def set_lora_info(
self, self,
A_buffer_qkv: torch.Tensor, A_buffer_qkv: torch.Tensor,
B_buffer_q: torch.Tensor, B_buffer_qkv: torch.Tensor,
B_buffer_kv: torch.Tensor,
): ):
self.set_lora = True self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_qkv = B_buffer_qkv
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
if getattr(self, "B_buffer_qkv", None) is None:
self.B_buffer_qkv = torch.empty(
(
B_buffer_q[0].shape[0],
output_dim_q + 2 * output_dim_kv,
B_buffer_q[0].shape[2],
),
dtype=B_buffer_q[0].dtype,
device=B_buffer_q[0].device,
)
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
B_buffer_kv[0]
)
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
B_buffer_kv[1]
)
# Offsets of q/k/v in output dimension
if getattr(self, "output_offset", None) is None:
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
if self.lora_backend.fuse_stacked_lora_b:
backend_kwargs["output_offset"] = self.output_offset
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
lora_output = self.lora_backend.run_qkv_lora( lora_output = self.lora_backend.run_qkv_lora(
x, x=x,
self.A_buffer_qkv, qkv_lora_a=self.A_buffer_qkv,
self.B_buffer_qkv, qkv_lora_b=self.B_buffer_qkv,
**backend_kwargs, base_output=base_output,
) output_offset=self.output_offset,
return ( max_qkv_out_dim=self.max_qkv_out_dim,
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
) )
return lora_output
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A return A
def slice_lora_b_weights( def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
self, B: List[torch.Tensor], tp_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
B_q, B_kv = B
base_layer = self.base_layer base_layer = self.base_layer
q_proj_shard_size = base_layer.q_proj_shard_size q_proj_shard_size = base_layer.q_proj_shard_size
kv_proj_shard_size = base_layer.kv_proj_shard_size kv_proj_shard_size = base_layer.kv_proj_shard_size
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
kv_start_idx = kv_proj_shard_size * kv_shard_id kv_start_idx = kv_proj_shard_size * kv_shard_id
kv_end_idx = kv_start_idx + kv_proj_shard_size kv_end_idx = kv_start_idx + kv_proj_shard_size
return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :] q_size, k_size, _ = base_layer.output_sizes
B_q_shard = B[q_start_idx:q_end_idx, :]
B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
return torch.concat(
(
B_q_shard,
B_k_shard,
B_v_shard,
),
dim=0,
)
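The `output_offset` tensor that is now built once in `__init__` simply records where the q, k, and v segments start inside this rank's stacked qkv output, and `max_qkv_out_dim` is the widest of the three segments (used when sizing kernel launches). A toy calculation under assumed per-rank shard sizes:

```python
import torch

# Assumed per-rank widths, e.g. 8 query heads and 2 KV heads with head_dim 128.
q_proj_shard_size, kv_proj_shard_size = 1024, 256

# Boundaries of the q / k / v segments in the stacked qkv output dimension.
output_offset = torch.tensor(
    [
        0,
        q_proj_shard_size,
        q_proj_shard_size + kv_proj_shard_size,
        q_proj_shard_size + 2 * kv_proj_shard_size,
    ],
    dtype=torch.int32,
)
print(output_offset.tolist())  # [0, 1024, 1280, 1536]

# Widest segment, used for deciding how many blocks the kernel launches.
max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)  # 1024
```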
class RowParallelLinearWithLoRA(BaseLayerWithLoRA): class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.B_buffer = B_buffer self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm( lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output, x=lora_a_output,
self.B_buffer[0], weights=self.B_buffer,
**backend_kwargs, base_output=base_output,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
) )
return lora_output
def forward(self, input_: torch.Tensor): def forward(self, input_: torch.Tensor):
# duplicate the logic in RowParallelLinear # duplicate the logic in RowParallelLinear
......
@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
q_name = weight_name q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj") k_name = weight_name.replace("q_proj", "k_proj")
v_name = weight_name.replace("q_proj", "v_proj") v_name = weight_name.replace("q_proj", "v_proj")
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj") qkv_name = weight_name.replace("q_proj", "qkv_proj")
# If k_proj doesn't have lora, initialize it to zero # If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
if "k_proj" in target_module if "k_proj" in target_module
else torch.zeros_like(weights[v_name]) else torch.zeros_like(weights[v_name])
) )
if "lora_A" in weight_name: weights[qkv_name] = torch.cat(
weights[qkv_name] = torch.cat( (
( weights[q_name],
weights[q_name], k_proj_weight,
k_proj_weight, weights[v_name],
weights[v_name], ),
), 0,
0, )
) weights.pop(q_name)
weights.pop(q_name) if "k_proj" in target_module:
if "k_proj" in target_module: weights.pop(k_name)
weights.pop(k_name) weights.pop(v_name)
weights.pop(v_name)
else:
weights[kv_name] = torch.stack(
[
k_proj_weight,
weights[v_name],
],
dim=0,
)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
elif "qkv_proj" in weight_name: elif "qkv_proj" in weight_name:
# If qkv_proj is already stacked, we normalize it following the SGL convention. # If qkv_proj is already stacked, we normalize it following the SGL convention.
qkv_name = weight_name qkv_name = weight_name
q_name = weight_name.replace("qkv_proj", "q_proj") q_name = weight_name.replace("qkv_proj", "q_proj")
k_name = weight_name.replace("qkv_proj", "k_proj") k_name = weight_name.replace("qkv_proj", "k_proj")
v_name = weight_name.replace("qkv_proj", "v_proj") v_name = weight_name.replace("qkv_proj", "v_proj")
kv_name = weight_name.replace("qkv_proj", "kv_proj")
if "lora_A" in weight_name: if "lora_A" in weight_name:
weights[qkv_name] = weights[qkv_name].repeat(3, 1) weights[qkv_name] = weights[qkv_name].repeat(3, 1)
else: # else: no-op as LoRA B weight is already stacked.
head_size = (
self.base_hf_config.hidden_size
// self.base_hf_config.num_attention_heads
)
weights[q_name], k_proj_weight, v_proj_weight = torch.split(
weights[qkv_name],
[
head_size * self.base_hf_config.num_attention_heads,
head_size * self.base_hf_config.num_key_value_heads,
head_size * self.base_hf_config.num_key_value_heads,
],
dim=0,
)
weights[kv_name] = torch.stack(
[k_proj_weight, v_proj_weight],
dim=0,
)
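The normalization above now treats LoRA A and LoRA B the same way: separate q/k/v entries are concatenated along dim 0 into a single `qkv_proj` entry (LoRA A becomes `(3 * r, hidden)`, LoRA B becomes `(out_q + 2 * out_kv, r)`), instead of keeping a separate stacked `kv_proj` tensor on the B side. A hedged toy illustration with assumed shapes, not the adapter code itself:

```python
import torch

r, hidden = 8, 64
out_q, out_kv = 64, 16  # e.g. 8 query heads vs 2 KV heads with head_dim 8

weights = {
    "q_proj.lora_A": torch.randn(r, hidden),
    "k_proj.lora_A": torch.randn(r, hidden),
    "v_proj.lora_A": torch.randn(r, hidden),
    "q_proj.lora_B": torch.randn(out_q, r),
    "k_proj.lora_B": torch.randn(out_kv, r),
    "v_proj.lora_B": torch.randn(out_kv, r),
}

# Both LoRA A and LoRA B are stacked the same way: rows laid out as [q | k | v].
for suffix in ("lora_A", "lora_B"):
    weights[f"qkv_proj.{suffix}"] = torch.cat(
        (
            weights.pop(f"q_proj.{suffix}"),
            weights.pop(f"k_proj.{suffix}"),
            weights.pop(f"v_proj.{suffix}"),
        ),
        dim=0,
    )

print(weights["qkv_proj.lora_A"].shape)  # torch.Size([24, 64]) == (3 * r, hidden)
print(weights["qkv_proj.lora_B"].shape)  # torch.Size([96, 8]) == (out_q + 2 * out_kv, r)
```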
def normalize_gate_up_proj( def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor] self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights: if up_name not in weights:
weights[up_name] = torch.zeros_like(weights[weight_name]) weights[up_name] = torch.zeros_like(weights[weight_name])
# FIXME: Add gate-only support for flashinfer in future implementations
assert self.lora_backend.name == "triton", ( assert self.lora_backend.name == "triton", (
f"LoRA weight initialization currently only supported for 'triton' backend. " f"LoRA weight initialization currently only supported for 'triton' backend. "
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration " f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends." f"or consider implementing custom initialization logic for other backends."
) )
if "lora_A" in weight_name: weights[gate_up_name] = torch.cat(
weights[gate_up_name] = torch.cat( (weights[weight_name], weights[up_name]), 0
(weights[weight_name], weights[up_name]), 0 )
)
else:
weights[gate_up_name] = torch.stack(
[weights[weight_name], weights[up_name]], dim=0
)
weights.pop(weight_name) weights.pop(weight_name)
if up_name in weights: if up_name in weights:
weights.pop(up_name) weights.pop(up_name)
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name gate_up_name = weight_name
if "lora_A" in weight_name: if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1) weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
else: # else: no-op as LoRA B weight is already stacked.
output_dim = weights[gate_up_name].shape[0] // 2
weights[gate_up_name] = torch.stack(
[
weights[gate_up_name][:output_dim, :],
weights[gate_up_name][output_dim:, :],
],
dim=0,
)
@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import ( from sglang.srt.lora.utils import (
LoRABatchInfo, LoRABatchInfo,
LoRAType, LoRAType,
get_customized_names_from_hf_names,
get_layer_id, get_layer_id,
get_normalized_lora_weight_names, get_normalized_lora_weight_names,
get_weight_name, get_weight_name,
@@ -345,40 +344,19 @@ class LoRAManager:
) )
self.lora_backend.set_batch_info(batch_info) self.lora_backend.set_batch_info(batch_info)
# TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
# this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
self.update_lora_info()
def update_lora_info(self): def update_lora_info(self):
""" """
Update all LoRA modules to associate them with the latest memory buffer. Update all LoRA modules to associate them with the latest memory buffer.
""" """
for layer_id, layer_modules in enumerate(self.lora_modules): for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items(): for module_name, module in layer_modules.items():
if "qkv_proj" in module_name: weight_name = get_weight_name(
module.set_lora_info( module_name, self.memory_pool.lora_weight_names
self.memory_pool.get_tensor( )
"qkv_proj", layer_id, LoRAType.LORA_A module.set_lora_info(
), self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor( self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
"q_proj", layer_id, LoRAType.LORA_B )
),
self.memory_pool.get_tensor(
"kv_proj", layer_id, LoRAType.LORA_B
),
)
else:
weight_name = get_weight_name(
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(
weight_name, layer_id, LoRAType.LORA_A
),
self.memory_pool.get_tensor(
weight_name, layer_id, LoRAType.LORA_B
),
)
def init_state( def init_state(
self, self,
@@ -405,6 +383,7 @@ class LoRAManager:
self.init_lora_weight_names() self.init_lora_weight_names()
self.init_lora_modules() self.init_lora_modules()
self.init_memory_pool() self.init_memory_pool()
self.update_lora_info()
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None): def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID. # Configs of all active LoRA adapters, indexed by LoRA ID.
@@ -461,9 +440,9 @@ class LoRAManager:
Add new LoRA weight names if needed based on the current `self.configs`. Add new LoRA weight names if needed based on the current `self.configs`.
""" """
# Target lora weight names for lora_a and lora_b modules respectively. self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules) self.target_modules
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B)) )
def load_lora_weights(self, lora_ref: LoRARef): def load_lora_weights(self, lora_ref: LoRARef):
""" """
@@ -479,15 +458,6 @@ class LoRAManager:
lora_adapter.initialize_weights() lora_adapter.initialize_weights()
self.loras[lora_ref.lora_id] = lora_adapter self.loras[lora_ref.lora_id] = lora_adapter
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.r for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def init_memory_pool(self): def init_memory_pool(self):
"""(Re)initialize the LoRA memory pool based on the current configurations.""" """(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool( self.memory_pool = LoRAMemoryPool(
@@ -512,12 +482,6 @@ class LoRAManager:
{} for _ in range(self.base_hf_config.num_hidden_layers) {} for _ in range(self.base_hf_config.num_hidden_layers)
] ]
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
self.target_modules, self.base_model
)
for module_name, module in self.base_model.named_modules(): for module_name, module in self.base_model.named_modules():
# TODO (lifuhuang): in the future, we should consider generalizing the # TODO (lifuhuang): in the future, we should consider generalizing the
# should_apply_lora function to support mapping by full module name instead # should_apply_lora function to support mapping by full module name instead
@@ -530,7 +494,7 @@ class LoRAManager:
continue continue
# The module should be converted if it is included in target_names # The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names: if module_name.split(".")[-1] in self.lora_weight_names:
layer_id = get_layer_id(module_name) layer_id = get_layer_id(module_name)
self.lora_modules[layer_id][module_name] = self.set_lora_module( self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module module_name, module
......
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
tp_size: int, tp_size: int,
tp_rank: int, tp_rank: int,
max_lora_rank: int, max_lora_rank: int,
lora_weight_names: Tuple[Set[str], Set[str]], lora_weight_names: Set[str],
base_model: torch.nn.Module, base_model: torch.nn.Module,
): ):
self.base_hf_config: AutoConfig = base_hf_config self.base_hf_config: AutoConfig = base_hf_config
@@ -62,9 +62,7 @@ class LoRAMemoryPool:
self.tp_size: int = tp_size self.tp_size: int = tp_size
self.tp_rank: int = tp_rank self.tp_rank: int = tp_rank
self.max_lora_rank: int = max_lora_rank self.max_lora_rank: int = max_lora_rank
self.lora_weight_names: Set[str] = lora_weight_names
# lora weight names for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
# Both A_buffer and B_buffer maps lora weight names to its buffer space. # Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape # A_buffer contains num_layer number of row-major tensors with shape
@@ -97,12 +95,8 @@ class LoRAMemoryPool:
""" """
if config.r > self.max_lora_rank: if config.r > self.max_lora_rank:
return False return False
weights_a, weights_b = get_normalized_lora_weight_names( weights = get_normalized_lora_weight_names(config.target_modules)
config.target_modules return weights.issubset(self.lora_weight_names)
)
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
self.lora_weight_names[1]
)
if isinstance(config, LoRAConfig): if isinstance(config, LoRAConfig):
return _can_support(config) return _can_support(config)
@@ -132,11 +126,9 @@ class LoRAMemoryPool:
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
""" """
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model) _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
c = get_stacked_multiply(module_name)
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
output_dim = divide(output_dim, self.tp_size) output_dim = divide(output_dim, self.tp_size)
return ( return (
c,
self.max_loras_per_batch, self.max_loras_per_batch,
output_dim, output_dim,
max_lora_dim, max_lora_dim,
@@ -165,13 +157,13 @@ class LoRAMemoryPool:
init_buffer( init_buffer(
self.A_buffer, self.A_buffer,
self.lora_weight_names[0], self.lora_weight_names,
self.get_lora_A_shape, self.get_lora_A_shape,
) )
init_buffer( init_buffer(
self.B_buffer, self.B_buffer,
self.lora_weight_names[1], self.lora_weight_names,
self.get_lora_B_shape, self.get_lora_B_shape,
) )
@@ -246,7 +238,7 @@ class LoRAMemoryPool:
return return
assert lora_adapter is not None assert lora_adapter is not None
lora_rank = lora_adapter.config.hf_config["r"] lora_rank = lora_adapter.config.r
for layer_id in range(self.num_layer): for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights layer_weights = lora_adapter.layers[layer_id].weights
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
@@ -256,73 +248,38 @@ class LoRAMemoryPool:
weight_name: None for weight_name in self.B_buffer weight_name: None for weight_name in self.B_buffer
} }
for name, weights in layer_weights.items(): for name, weights in layer_weights.items():
lora_weight_name = get_weight_name(name, self.lora_weight_names)
if "lora_A" in name: if "lora_A" in name:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_A
)
temp_A_buffer[lora_weight_name] = weights temp_A_buffer[lora_weight_name] = weights
else: else:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_B
)
temp_B_buffer[lora_weight_name] = weights temp_B_buffer[lora_weight_name] = weights
if self.tp_size > 1: if self.tp_size > 1:
cur_layer_modules = lora_modules[layer_id] cur_layer_modules = lora_modules[layer_id]
for module_name, module in cur_layer_modules.items(): for module_name, module in cur_layer_modules.items():
weight_name = get_weight_name( weight_name = get_weight_name(module_name, self.lora_weight_names)
module_name, self.lora_weight_names, LoRAType.LORA_A
)
if temp_A_buffer[weight_name] is None: if temp_A_buffer[weight_name] is None:
# Skip weight slicing if the weight is not present in the adapter # Skip weight slicing if the weight is not present in the adapter
continue continue
if "qkv_proj" in module_name: temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights( temp_A_buffer[weight_name], self.tp_rank
temp_A_buffer["qkv_proj"], self.tp_rank )
) temp_B_buffer[weight_name] = module.slice_lora_b_weights(
temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = ( temp_B_buffer[weight_name], self.tp_rank
module.slice_lora_b_weights( )
[temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
self.tp_rank,
)
)
else:
# TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
# Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
# B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
# FlashInfer LoRA backend.
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer[weight_name], self.tp_rank
)
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
temp_B_buffer[weight_name], self.tp_rank
)
for name, weights in temp_A_buffer.items(): for name, weights in temp_A_buffer.items():
c = get_stacked_multiply(name) c = get_stacked_multiply(name)
buffer_view = self.A_buffer[name][layer_id][buffer_id][ target_buffer = self.A_buffer[name][layer_id]
: lora_rank * c, : buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
]
load_lora_weight_tensor(buffer_view, weights) load_lora_weight_tensor(buffer_view, weights)
for name, weights in temp_B_buffer.items(): for name, weights in temp_B_buffer.items():
c = get_stacked_multiply(name) target_buffer = self.B_buffer[name][layer_id]
if c > 1: buffer_view = target_buffer[buffer_id, :, :lora_rank]
for stacked_id in range(c): load_lora_weight_tensor(buffer_view, weights)
buffer_view = self.B_buffer[name][layer_id][stacked_id][
buffer_id
][:, :lora_rank]
weight_slice = (
weights[stacked_id] if weights is not None else None
)
load_lora_weight_tensor(buffer_view, weight_slice)
else:
buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
:, :lora_rank
]
load_lora_weight_tensor(buffer_view, weights)
def get_tensor( def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType self, weight_name: str, layer_id: int, lora_type: LoRAType
......
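After this change each normalized weight name owns a single buffer: for LoRA A the stacked factor (3 for `qkv_proj`, 2 for `gate_up_proj`) is folded into the row dimension, and for LoRA B the stacked projections share one wide output dimension, so an adapter with a smaller rank simply fills a prefix of its slot. A toy sketch of that layout with assumed sizes:

```python
import torch

# Assumed toy sizes, illustrative only.
max_loras_per_batch, max_lora_rank = 4, 16
hidden, out_q, out_kv = 64, 64, 16
stacked = 3  # qkv_proj stacks q, k, v

# One buffer per weight name ("qkv_proj" here), shared by all loaded adapters.
A_buffer = torch.zeros(max_loras_per_batch, stacked * max_lora_rank, hidden)
B_buffer = torch.zeros(max_loras_per_batch, out_q + 2 * out_kv, max_lora_rank)

# An adapter with rank 8 occupies only a prefix of its buffer slot.
lora_rank, buffer_id = 8, 2
adapter_A = torch.randn(stacked * lora_rank, hidden)    # (3 * r, hidden)
adapter_B = torch.randn(out_q + 2 * out_kv, lora_rank)  # (out, r)

A_buffer[buffer_id, : stacked * lora_rank, :].copy_(adapter_A)  # rows [: 3 * r]
B_buffer[buffer_id, :, :lora_rank].copy_(adapter_B)             # columns [: r]
```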
@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
) )
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
partial_sum += tl.load(output_ptr, mask=output_mask) partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask)
......
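The one-character kernel fix above matters because the boundary mask is a 2D block of booleans: it must be combined with the elementwise bitwise `&`, whereas Python's `and` asks for a single truth value. Plain PyTorch shows the same distinction; the Triton block program follows the same rule:

```python
import torch

s_in_range = torch.tensor([[True], [True], [False]])  # like s_offset[:, None] < seg_len
n_in_range = torch.tensor([[True, False, True]])      # like n_offset[None, :] < n_size

mask = s_in_range & n_in_range  # elementwise AND, broadcasts to a (3, 3) block
print(mask.sum().item())        # 4 valid positions

try:
    s_in_range and n_in_range   # `and` wants one bool for the whole tensor
except RuntimeError as err:
    print("Python `and` is ambiguous for multi-element masks:", err)
```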
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
return int(match.group(1)) return int(match.group(1))
def get_customized_names_from_hf_names(
hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
"""
This function takes in a set of huggingface style module names:
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
and outputs a set of module names of customized sglang layers:
e.g., {"qkv_proj", "o_proj"}
"""
if hasattr(base_model, "get_module_name"):
return {base_model.get_module_name(name) for name in hf_module_names}
else:
"""
Fallback solution of mapping from config module name to module name in model class.
Please check if it aligns with your base model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return {params_mapping.get(name, name) for name in hf_module_names}
def get_hidden_dim( def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]: ) -> Tuple[int]:
@@ -95,22 +67,9 @@ def get_hidden_dim(
head_dim = getattr( head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads config, "head_dim", config.hidden_size // config.num_attention_heads
) )
# TODO: the special handling of qkv will be addressed in #8940.
if module_name == "qkv_proj": if module_name == "qkv_proj":
return ( return config.hidden_size, head_dim * (
config.hidden_size, config.num_attention_heads + config.num_key_value_heads * 2
None, # qkv_proj is only used in LoRA A
)
elif module_name == "kv_proj":
return (
None, # kv_proj is only used in LoRA B
head_dim * config.num_key_value_heads,
)
elif module_name == "q_proj":
return (
None, # q_proj is only used in LoRA B
head_dim * config.num_attention_heads,
) )
elif module_name == "o_proj": elif module_name == "o_proj":
return ( return (
@@ -118,7 +77,7 @@ def get_hidden_dim(
config.hidden_size, config.hidden_size,
) )
elif module_name == "gate_up_proj": elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size return config.hidden_size, config.intermediate_size * 2
elif module_name == "down_proj": elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size return config.intermediate_size, config.hidden_size
else: else:
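The returned output dims are just the widths of the stacked projections. For a Llama-3.1-8B-style config (hidden size 4096, 32 attention heads, 8 KV heads, intermediate size 14336) the arithmetic works out as follows; a hedged calculation, not the SGLang function itself:

```python
hidden_size = 4096
num_attention_heads, num_key_value_heads = 32, 8
head_dim = hidden_size // num_attention_heads  # 128
intermediate_size = 14336

# qkv_proj: q rows plus k and v rows, stacked along the output dimension.
qkv_out = head_dim * (num_attention_heads + 2 * num_key_value_heads)
print(hidden_size, qkv_out)      # 4096 6144  (6144 = 4096 + 2 * 1024)

# gate_up_proj: gate and up halves stacked along the output dimension.
gate_up_out = 2 * intermediate_size
print(hidden_size, gate_up_out)  # 4096 28672
```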
@@ -127,26 +86,22 @@
def get_normalized_lora_weight_names( def get_normalized_lora_weight_names(
target_modules: Iterable[str], target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]: ) -> set[str]:
""" """
Mapping a list of target module name to names of the normalized LoRA weights. Mapping a list of target module name to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
""" """
params_mapping = { params_mapping = {
"q_proj": (["qkv_proj"], ["q_proj"]), "q_proj": "qkv_proj",
"k_proj": (["qkv_proj"], ["kv_proj"]), "k_proj": "qkv_proj",
"v_proj": (["qkv_proj"], ["kv_proj"]), "v_proj": "qkv_proj",
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]), "gate_proj": "gate_up_proj",
"up_proj": (["gate_up_proj"], ["gate_up_proj"]), "up_proj": "gate_up_proj",
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
} }
result = (set(), set()) result = set()
for name in target_modules: for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name])) weight_name = params_mapping.get(name, name)
result[0].update(lora_a) result.add(weight_name)
result[1].update(lora_b)
return result return result
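Since the function now returns a single set instead of an (A, B) pair, a typical PEFT-style `target_modules` list collapses as follows; this is a toy restatement of the mapping shown above, for illustration:

```python
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
normalized = {params_mapping.get(name, name) for name in target_modules}
print(sorted(normalized))  # ['down_proj', 'gate_up_proj', 'o_proj', 'qkv_proj']
```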
@@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
""" """
stacked_rank = { stacked_rank = {
"qkv_proj": 3, "qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2, "gate_up_proj": 2,
} }
return stacked_rank[module_name] if module_name in stacked_rank else 1 return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_weight_name( def get_weight_name(
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType target_name: str, lora_weight_names: Tuple[Set[str]]
) -> Optional[str]: ) -> Optional[str]:
""" """
target_name is name of a given module, Get the weight name in lora_weight_names that can match target_name.
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
If there is a weight name in lora_weight_names that can match target_name, return this name If there is a weight name in lora_weight_names that can match target_name, return this name
Else raise ValueError. Else raise ValueError.
""" """
idx = 0 if lora_type == LoRAType.LORA_A else 1 for weight_name in lora_weight_names:
for weight_name in lora_weight_names[idx]:
if weight_name in target_name: if weight_name in target_name:
return weight_name return weight_name
raise ValueError( raise ValueError(
@@ -180,9 +133,4 @@ def get_weight_name(
) )
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
@@ -501,23 +501,16 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
def get_hidden_dim(self, module_name): def get_hidden_dim(self, module_name):
# return input_dim, output_dim # return input_dim, output_dim
# TODO: the special handling of qkv will be addressed in #8940.
if module_name == "qkv_proj": if module_name == "qkv_proj":
return ( return (
self.config.hidden_size, self.config.hidden_size,
None, # qkv_proj is only used in LoRA A self.config.head_dim
* (
self.config.num_attention_heads
+ self.config.num_key_value_heads * 2
),
) )
elif module_name == "kv_proj": elif module_name == "o_proj":
return (
None, # kv_proj is only used in LoRA B
self.config.head_dim * self.config.num_key_value_heads,
)
elif module_name == "q_proj":
return (
None, # q_proj is only used in LoRA B
self.config.head_dim * self.config.num_attention_heads,
)
elif module_name in ["o_proj"]:
return ( return (
self.config.head_dim * self.config.num_attention_heads, self.config.head_dim * self.config.num_attention_heads,
self.config.hidden_size, self.config.hidden_size,
@@ -527,7 +520,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
"Currently SGLang requires uniform intermediate size for all layers. " "Currently SGLang requires uniform intermediate size for all layers. "
"Please file an issue if you need support for non-uniform intermediate sizes." "Please file an issue if you need support for non-uniform intermediate sizes."
) )
return self.config.hidden_size, self.config.intermediate_size[0] return self.config.hidden_size, self.config.intermediate_size[0] * 2
elif module_name == "down_proj": elif module_name == "down_proj":
assert len(set(self.config.intermediate_size)) == 1, ( assert len(set(self.config.intermediate_size)) == 1, (
"Currently SGLang requires uniform intermediate size for all layers. " "Currently SGLang requires uniform intermediate size for all layers. "
......