Commit 5e766c85 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Kernel] Prototype integration of flux kernels

parent feeb058b
...@@ -33,6 +33,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, ...@@ -33,6 +33,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union) Union)
from unittest.mock import patch from unittest.mock import patch
import flux
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
...@@ -207,6 +208,10 @@ class GroupCoordinator: ...@@ -207,6 +208,10 @@ class GroupCoordinator:
self.use_hpu_communicator = use_hpu_communicator self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator self.use_xpu_communicator = use_xpu_communicator
# Initialize flux shared memory (pynvshmem); only done when the device group has more than one rank
if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group)
# lazy import to avoid documentation build error # lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import ( from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce) CustomAllreduce)
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Optional from typing import Optional, List
import flux
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
...@@ -13,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -13,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -161,6 +163,144 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -161,6 +163,144 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, layer.weight, bias) return F.linear(x, layer.weight, bias)
class GemmRS(LinearMethodBase):
    """Linear method backed by flux's fused GEMM + ReduceScatter kernel.

    No quantization is applied. ``create_weights`` registers a plain
    ``weight`` parameter on the layer and builds the ``flux.GemmRS`` op over
    the tensor-parallel device group; ``apply`` runs that op.
    """

    # Upper bound on M handed to the flux op.
    # TODO: Pass in correctly.
    _MAX_M = 8192

    def __init__(self, separate_bias_add: bool = False):
        """Store configuration flags.

        Args:
            separate_bias_add: Stored on the instance; not read anywhere
                else in this class.
        """
        # Function-scope import: `os` is not among the module-level imports
        # visible for this file, so relying on a global `os` could raise
        # NameError here.
        import os

        self.separate_bias_add = separate_bias_add
        # True only when the LLAMA_NN environment variable is exactly '1'.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register ``layer.weight`` and construct the fused flux op.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and is
        created with ``requires_grad=False``.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        # The LLAMA_NN and default configurations differ only in
        # `transpose_weight` (True when LLAMA_NN == '1', False otherwise),
        # so a single constructor call covers both.
        self.gemm_rs_op = flux.GemmRS(
            get_tp_group().device_group,
            nnodes=1,  # One node
            max_m=self._MAX_M,  # Max M. TODO: Pass in correctly.
            n_dim=output_size,  # N
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max m.
            input_dtype=params_dtype,
            # Note: transpose_weight=False means that B is transposed
            transpose_weight=self.use_llama_nn,
            # Note: bfloat16 requires fuse_reduction=False.
            fuse_reduction=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused GEMM + ReduceScatter op on ``x`` and ``layer.weight``.

        ``bias`` must be ``None``; the fused op is invoked without a bias.
        """
        assert bias is None, "GemmRS does not support a bias term"
        output = self.gemm_rs_op.forward(x, layer.weight)
        # Drop dimension 0 when it has size 1 (no-op otherwise).
        # NOTE(review): presumably the flux op returns a leading singleton
        # dimension — confirm against the flux API.
        return output.squeeze(0)
class AGCook(LinearMethodBase):
    """Linear method backed by flux's fused AllGather + GEMM kernel.

    No quantization is applied. ``create_weights`` registers a plain
    ``weight`` parameter on the layer and builds the ``flux.AGKernel`` op
    over the tensor-parallel device group; ``apply`` runs that op.
    """

    # Value passed as `full_m` to the flux op.
    # TODO: Pass in correctly.
    _FULL_M = 8192

    def __init__(self, separate_bias_add: bool = False):
        """Store configuration flags.

        Args:
            separate_bias_add: Stored on the instance; not read anywhere
                else in this class.
        """
        # Function-scope import: `os` is not among the module-level imports
        # visible for this file, so relying on a global `os` could raise
        # NameError here.
        import os

        self.separate_bias_add = separate_bias_add
        # True only when the LLAMA_NN environment variable is exactly '1'.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register ``layer.weight`` and construct the fused flux op.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and is
        created with ``requires_grad=False``. Its two dimensions are passed
        to the flux op as ``n_dim`` and ``k_dim`` respectively.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        # The LLAMA_NN and default configurations differ only in
        # `transpose_weight` (True when LLAMA_NN == '1', False otherwise),
        # so a single constructor call covers both.
        self.ag_gemm_op = flux.AGKernel(
            get_tp_group().device_group,
            nnodes=1,  # One node
            full_m=self._FULL_M,  # Max M. TODO: Pass in correctly.
            n_dim=weight.shape[0],  # N
            k_dim=weight.shape[1],  # K
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max m.
            input_dtype=params_dtype,
            output_dtype=params_dtype,
            # Note: transpose_weight=False means that B is transposed
            transpose_weight=self.use_llama_nn,
            # Note: if local_copy=True, I hit the following runtime error:
            # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
            # Check failed: 33554432((input.numel() * input.element_size()))
            #               == 139836453421056((this->chunk_size))
            local_copy=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused AllGather + GEMM op on ``x`` and ``layer.weight``.

        ``bias`` must be ``None``; the fused op is invoked without a bias.
        """
        assert bias is None, "AGCook does not support a bias term"
        return self.ag_gemm_op.forward(x, layer.weight)
class LinearBase(torch.nn.Module): class LinearBase(torch.nn.Module):
"""Base linear layer. """Base linear layer.
...@@ -181,6 +321,8 @@ class LinearBase(torch.nn.Module): ...@@ -181,6 +321,8 @@ class LinearBase(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
fuse_gemm_rs: bool = False,
fuse_ag_gemm: bool = False,
): ):
super().__init__() super().__init__()
...@@ -191,9 +333,14 @@ class LinearBase(torch.nn.Module): ...@@ -191,9 +333,14 @@ class LinearBase(torch.nn.Module):
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
if quant_config is None: if fuse_gemm_rs:
self.quant_method: Optional[ assert (quant_config is None)
QuantizeMethodBase] = UnquantizedLinearMethod() self.quant_method: Optional[QuantizeMethodBase] = GemmRS()
elif fuse_ag_gemm:
assert (quant_config is None)
self.quant_method = AGCook()
elif quant_config is None:
self.quant_method = UnquantizedLinearMethod()
else: else:
self.quant_method = quant_config.get_quant_method(self, self.quant_method = quant_config.get_quant_method(self,
prefix=prefix) prefix=prefix)
...@@ -308,9 +455,10 @@ class ColumnParallelLinear(LinearBase): ...@@ -308,9 +455,10 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None, output_sizes: Optional[list[int]] = None,
prefix: str = ""): prefix: str = "",
fuse_ag_gemm: bool = False):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix) quant_config, prefix, fuse_ag_gemm=fuse_ag_gemm)
self.gather_output = gather_output self.gather_output = gather_output
...@@ -447,7 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -447,7 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
fuse_ag_gemm: bool = False):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
...@@ -458,7 +607,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -458,7 +607,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -723,7 +873,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -723,7 +873,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
fuse_ag_gemm: bool = False):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
...@@ -756,7 +907,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -756,7 +907,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)
def _get_shard_offset_mapping(self, loaded_shard_id: str): def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = { shard_offset_mapping = {
...@@ -1060,12 +1212,15 @@ class RowParallelLinear(LinearBase): ...@@ -1060,12 +1212,15 @@ class RowParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True, reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
fuse_gemm_rs: bool = False):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix) quant_config, prefix, fuse_gemm_rs=fuse_gemm_rs)
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
if fuse_gemm_rs:
self.reduce_results = False
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
......
...@@ -33,7 +33,7 @@ import vllm.envs as envs ...@@ -33,7 +33,7 @@ import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -71,6 +71,7 @@ class LlamaMLP(nn.Module): ...@@ -71,6 +71,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
prefix: str = "", prefix: str = "",
last_layer: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
...@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module): ...@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj", prefix=f"{prefix}.gate_up_proj",
fuse_ag_gemm=True,
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
input_size=intermediate_size, input_size=intermediate_size,
...@@ -86,6 +88,7 @@ class LlamaMLP(nn.Module): ...@@ -86,6 +88,7 @@ class LlamaMLP(nn.Module):
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
fuse_gemm_rs=(not last_layer),
) )
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
...@@ -106,6 +109,7 @@ class LlamaAttention(nn.Module): ...@@ -106,6 +109,7 @@ class LlamaAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
first_layer: bool,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
...@@ -148,6 +152,7 @@ class LlamaAttention(nn.Module): ...@@ -148,6 +152,7 @@ class LlamaAttention(nn.Module):
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
fuse_ag_gemm=(not first_layer),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
...@@ -156,6 +161,7 @@ class LlamaAttention(nn.Module): ...@@ -156,6 +161,7 @@ class LlamaAttention(nn.Module):
bias=bias_o_proj, bias=bias_o_proj,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
fuse_gemm_rs=True,
) )
is_neox_style = True is_neox_style = True
...@@ -223,6 +229,11 @@ class LlamaDecoderLayer(nn.Module): ...@@ -223,6 +229,11 @@ class LlamaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
# Hack: pass in whether this is the first/last layer
# so we know if we can rewrite AllReduce -> ReduceScatter + AllGather,
# and then propagate the AllGather to the next layer.
first_layer: bool,
last_layer: bool,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
...@@ -252,6 +263,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -252,6 +263,7 @@ class LlamaDecoderLayer(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads", num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads), config.num_attention_heads),
first_layer=first_layer,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
...@@ -268,12 +280,16 @@ class LlamaDecoderLayer(nn.Module): ...@@ -268,12 +280,16 @@ class LlamaDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=getattr(config, "mlp_bias", False), bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
last_layer=last_layer,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.first_layer = first_layer
self.last_layer = last_layer
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -287,17 +303,35 @@ class LlamaDecoderLayer(nn.Module): ...@@ -287,17 +303,35 @@ class LlamaDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
else: else:
assert (hidden_states.shape == residual.shape)
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
# Partition residual
if self.first_layer:
n_slices = get_tensor_model_parallel_world_size()
residual_slices = torch.chunk(residual, n_slices, dim=0)
my_residual = residual_slices[get_tensor_model_parallel_rank()]
else:
my_residual = residual
hidden_states = self.self_attn(positions=positions, hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
attn_metadata=attn_metadata) attn_metadata=attn_metadata)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( assert (hidden_states.shape == my_residual.shape)
hidden_states, residual) hidden_states, my_residual = self.post_attention_layernorm(
hidden_states, my_residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.last_layer:
residual = tensor_model_parallel_all_gather(my_residual, 0)
else:
residual = my_residual
assert (hidden_states.shape == residual.shape)
return hidden_states, residual return hidden_states, residual
...@@ -335,7 +369,9 @@ class LlamaModel(nn.Module): ...@@ -335,7 +369,9 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: layer_type(config=config, lambda prefix, first_layer, last_layer: layer_type(config=config,
first_layer=first_layer,
last_layer=last_layer,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix), prefix=prefix),
......
...@@ -548,14 +548,31 @@ def make_layers( ...@@ -548,14 +548,31 @@ def make_layers(
"""Make a list of layers with the given layer function, taking """Make a list of layers with the given layer function, taking
pipeline parallelism into account. pipeline parallelism into account.
""" """
import inspect
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
start_layer, end_layer = get_pp_indices(num_hidden_layers, start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group, get_pp_group().rank_in_group,
get_pp_group().world_size) get_pp_group().world_size)
# Determine if layer_fn accepts first/last args by inspecting its signature
sig = inspect.signature(layer_fn)
has_firstlast_args = ('first_layer'
in sig.parameters) and ('last_layer'
in sig.parameters)
def make_one_layer(idx, start_layer, end_layer):
    """Build the layer at index ``idx`` and pass it to ``maybe_offload_to_cpu``.

    When ``layer_fn`` accepts both ``first_layer`` and ``last_layer``
    (``has_firstlast_args``), it is also told whether ``idx`` is the first
    (``idx == start_layer``) or last (``idx == end_layer - 1``) layer of the
    given range; otherwise only ``prefix`` is supplied.
    """
    layer_kwargs = {"prefix": f"{prefix}.{idx}"}
    if has_firstlast_args:
        layer_kwargs["first_layer"] = (idx == start_layer)
        layer_kwargs["last_layer"] = (idx == end_layer - 1)
    return maybe_offload_to_cpu(layer_fn(**layer_kwargs))
modules = torch.nn.ModuleList( modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [ [PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) make_one_layer(idx, start_layer, end_layer)
for idx in range(start_layer, end_layer) for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules return start_layer, end_layer, modules
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment