OpenDAS / ColossalAI · Commits

Commit 84fd7c1d (unverified)
Authored Mar 18, 2022 by HELSON; committed by GitHub on Mar 18, 2022
Parent: af185b55

add moe context, moe utilities and refactor gradient handler (#455)
Showing 11 changed files with 255 additions and 125 deletions (+255 −125).
colossalai/context/__init__.py                                              +1   -0
colossalai/context/moe_context.py                                         +151   -0
colossalai/context/parallel_context.py                                      +0   -8
colossalai/context/random/_helper.py                                        +3   -8
colossalai/core.py                                                          +2   -1
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py       +2  -28
colossalai/engine/gradient_handler/_moe_gradient_handler.py                +10  -37
colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py   +3  -33
colossalai/engine/gradient_handler/utils.py                                +29   -0
colossalai/utils/moe.py                                                    +51   -0
tests/test_moe/short_test.py                                                +3  -10
colossalai/context/__init__.py

 from .config import Config, ConfigException
 from .parallel_context import ParallelContext
+from .moe_context import MoeContext
 from .parallel_mode import ParallelMode
 from .process_group_initializer import *
 from .random import *
colossalai/context/moe_context.py (new file, 0 → 100644)

import torch
import torch.distributed as dist

from .parallel_mode import ParallelMode


def _check_sanity():
    from colossalai.core import global_context as gpc
    if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
        raise NotImplementedError("MoE is not compatible with tensor or "
                                  "pipeline parallelism at present.")


class MoeInfo:
    """MoE parallelism information, storing parallel sizes and groups."""

    def __init__(self, ep_size: int, dp_size: int):
        _check_sanity()
        self.ep_size = ep_size
        self.dp_size = dp_size
        self.ep_group = None
        # Data parallel group for experts: since ep_group differs from the global
        # layout, this dp_group may differ from get_group(ParallelMode.DATA).
        self.dp_group = None

        # Here we assume tensor parallel size = 1; otherwise MoE can't be used.
        # Since the TENSOR and DATA parallel groups have already been created,
        # we can use them directly.
        if ep_size == 1:
            from colossalai.core import global_context as gpc
            self.ep_group = gpc.get_group(ParallelMode.TENSOR)
            self.dp_group = gpc.get_group(ParallelMode.DATA)
            return

        if dp_size == 1:
            from colossalai.core import global_context as gpc
            self.ep_group = gpc.get_group(ParallelMode.DATA)
            self.dp_group = gpc.get_group(ParallelMode.TENSOR)
            return

        rank = dist.get_rank()
        # Create expert parallel groups
        for i in range(dp_size):
            ranks = [i * ep_size + j for j in range(ep_size)]
            group = dist.new_group(ranks)
            if rank in ranks:
                self.ep_group = group

        # Create data parallel groups
        for j in range(ep_size):
            ranks = [i * ep_size + j for i in range(dp_size)]
            group = dist.new_group(ranks)
            if rank in ranks:
                self.dp_group = group
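For a concrete picture of the rank layout MoeInfo builds, the sketch below recomputes the same rank lists in plain Python for a hypothetical 8-GPU run with ep_size=4 and dp_size=2; the actual group handles are only created by dist.new_group inside the class.

# Illustration only: MoeInfo's rank arithmetic for a hypothetical world of
# 8 GPUs split into ep_size=4, dp_size=2.
ep_size, dp_size = 4, 2

# Expert parallel groups: consecutive ranks exchange expert traffic.
ep_groups = [[i * ep_size + j for j in range(ep_size)] for i in range(dp_size)]
# Data parallel groups: strided ranks hold replicas of the same expert.
dp_groups = [[i * ep_size + j for i in range(dp_size)] for j in range(ep_size)]

print(ep_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)    # [[0, 4], [1, 5], [2, 6], [3, 7]]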
class MoeContext:
    """MoE parallel context manager. This class manages the different parallel
    groups in the MoE context, as well as the MoE auxiliary loss in training.
    """
    __instance = None

    @staticmethod
    def get_instance():
        if MoeContext.__instance is None:
            MoeContext.__instance = MoeContext()
        return MoeContext.__instance

    def __init__(self):
        self.world_size = 1
        # Users may want to set the maximum expert parallel size smaller than the
        # world size, since very low bandwidth across nodes may constrain MoE
        # performance. A maximum expert parallel size naturally implies a minimum
        # data parallel size.
        self.max_ep_size = 1
        self.min_dp_size = 1
        self.aux_loss = None
        self.use_kernel_optim = True

        self.has_setup = False
        self._info_dict = dict()

    @property
    def information(self):
        return self._info_dict

    @property
    def is_initialized(self):
        return self.has_setup

    def setup(self, seed: int, use_kernel_optim: bool = True):
        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
        _check_sanity()
        assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first"

        self.world_size = dist.get_world_size()

        from colossalai.core import global_context as gpc
        self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
        assert self.world_size % self.max_ep_size == 0, \
            "Maximum expert parallel size must be a factor of the number of GPUs"
        self.min_dp_size = self.world_size // self.max_ep_size

        # Enabling kernel optimization may raise errors in some cases,
        # so users can disable it manually.
        self.use_kernel_optim = use_kernel_optim

        from .random import moe_set_seed
        moe_set_seed(seed)
        self.has_setup = True

    def get_info(self, num_experts: int):
        """Automatically deploys experts and returns parallel information about
        distributed communication groups.
        """
        gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
        lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less

        assert gt_flag or lt_flag, "Automatic expert placement does not support this situation right now."

        # If the number of experts is a multiple of the maximum expert parallel size,
        # each GPU holds several distinct experts, so its data parallel size is 1.
        # Otherwise there is only one expert on each GPU and the data parallel size
        # must be calculated.
        dp_size = 1 if gt_flag else self.max_ep_size // num_experts
        ep_size = self.max_ep_size // dp_size

        # Calculate the number of experts on each GPU
        num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size

        # Don't forget to multiply by the minimum data parallel size
        dp_size *= self.min_dp_size
        if ep_size not in self.information:
            self.information[ep_size] = MoeInfo(ep_size, dp_size)

        return num_local_experts, self.information[ep_size]
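The placement arithmetic in get_info is easiest to see with numbers. The sketch below mirrors the same logic as a standalone function for a hypothetical context with max_ep_size=4 and min_dp_size=2 (i.e. 8 GPUs); the names are placeholders for the instance attributes above.

# Illustration only: get_info's placement arithmetic, standalone.
def place(num_experts, max_ep_size=4, min_dp_size=2):
    gt_flag = num_experts % max_ep_size == 0
    lt_flag = max_ep_size % num_experts == 0
    assert gt_flag or lt_flag
    dp_size = 1 if gt_flag else max_ep_size // num_experts
    ep_size = max_ep_size // dp_size
    num_local_experts = 1 if lt_flag else num_experts // max_ep_size
    return num_local_experts, ep_size, dp_size * min_dp_size

print(place(8))    # (2, 4, 2): two experts per GPU, experts sharded over 4 ranks
print(place(2))    # (1, 2, 4): one expert per GPU, each expert replicated 4 times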
    def set_kernel_not_use(self):
        self.use_kernel_optim = False

    def reset_loss(self):
        self.aux_loss = 0

    def add_loss(self, loss):
        self.aux_loss += loss

    def get_loss(self):
        return self.aux_loss
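The loss hooks at the bottom of the class are the accumulation API that MoE routers are expected to drive. A minimal usage sketch, assuming a router implementation that calls add_loss during the forward pass and a hypothetical auxiliary-loss weight of 0.01:

# Sketch only: auxiliary-loss bookkeeping around one training step.
from colossalai.core import moe_context

moe_context.reset_loss()              # zero the accumulator before the forward pass
# output = model(tokens)              # routers call moe_context.add_loss(...) here
# loss = criterion(output, target) + 0.01 * moe_context.get_loss()    # 0.01 is a hypothetical weight
# loss.backward()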
colossalai/context/parallel_context.py

@@ -9,7 +9,6 @@ import torch
 import torch.distributed as dist

 from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
-from colossalai.global_variables import moe_env
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER

@@ -407,13 +406,6 @@ class ParallelContext:
                 # add this config to initialize later
                 pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))

-        # initialization for moe environment
-        if parallel_config is not None and 'moe' in parallel_config:
-            param = parallel_config['moe']
-            assert 'size' in param, "Moe model parallel size should be given"
-            moe_env.setup(param['size'])
-            pg_init.append(dict(type=INITIALIZER_MAPPING['moe']))
-
         # run initialization of different process groups
         for initializer_cfg in pg_init:
             cfg = initializer_cfg.copy()
colossalai/context/random/_helper.py

@@ -147,15 +147,10 @@ def with_seed(func, parallel_mode: ParallelMode):
 def moe_set_seed(seed):
     if torch.cuda.is_available():
         from colossalai.core import global_context as gpc
-        moe_mp_rank = gpc.get_local_rank(ParallelMode.MOE_MODEL)
-        moe_mp_seed = seed + moe_mp_rank
-        add_seed(ParallelMode.MOE_MODEL, moe_mp_seed)
-
         global_rank = gpc.get_global_rank()
-        add_seed(ParallelMode.TENSOR, global_rank, True)
-        print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, "
-              f"tensor seed {global_rank}", flush=True)
+        diff_seed = seed + global_rank
+        add_seed(ParallelMode.TENSOR, diff_seed, True)
+        print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True)


 def reset_seeds():
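The new seeding logic is deliberately simple: each rank offsets the base seed by its global rank before registering it for the TENSOR mode, so no two ranks share dropout or initialization randomness. A quick sketch of the resulting seeds for a hypothetical 4-rank job:

# Illustration only: per-rank TENSOR-mode seeds produced by moe_set_seed(42)
# on a hypothetical 4-rank job.
seed, world_size = 42, 4
for global_rank in range(world_size):
    print(f"rank {global_rank}: tensor seed {seed + global_rank}")
# rank 0: tensor seed 42 ... rank 3: tensor seed 45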
colossalai/core.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from colossalai.context import ParallelContext
+from colossalai.context import ParallelContext, MoeContext

 global_context = ParallelContext.get_instance()
+moe_context = MoeContext.get_instance()
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py

 #!/usr/bin/env python
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
+from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module

@@ -23,26 +19,4 @@ class DataParallelGradientHandler(BaseGradientHandler):
         """
         # TODO: add memory buffer
         if gpc.data_parallel_size > 1:
-            # bucketize and all-reduce
-            buckets = {}
-            # Pack the buckets.
-            for param in self._model.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = param.data.type()
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-                    # param.main_grad = param.grad
-
-            # For each bucket, all-reduce and copy all-reduced grads.
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                coalesced /= gpc.get_world_size(ParallelMode.DATA)
-                dist.all_reduce(coalesced, group=gpc.get_group(ParallelMode.DATA))
-                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                    buf.copy_(synced)
+            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
colossalai/engine/gradient_handler/_moe_gradient_handler.py

-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.core import global_context as gpc
+from colossalai.core import global_context as gpc, moe_context as moe_env
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.global_variables import moe_env
+from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
+from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module

@@ -21,41 +20,15 @@ class MoeGradientHandler(BaseGradientHandler):
         Then running an all-reduce operation for all parameters in experts
         across moe model parallel group
         """
-        moe_data = moe_env.data_parallel_size
         global_data = gpc.data_parallel_size

         if global_data > 1:
-            # bucketize and all-reduce
-            buckets = {}
-            # Pack the buckets.
-            for param in self._model.parameters():
-                if param.requires_grad and \
-                        param.grad is not None and \
-                        not hasattr(param, 'moe_param'):
-                    tp = param.data.type()
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-                    # param.main_grad = param.grad
-
-            # For each bucket, all-reduce and copy all-reduced grads.
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                coalesced /= gpc.get_world_size(ParallelMode.DATA)
-                dist.all_reduce(coalesced, group=gpc.get_group(ParallelMode.DATA))
-                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                    buf.copy_(synced)
-
-        if global_data > 1:
-            for param in self._model.parameters():
-                if not param.requires_grad or param.grad is None:
-                    continue
-                if moe_data > 1 and hasattr(param, 'moe_param'):
-                    param.grad.data /= moe_data
-                    dist.all_reduce(param.grad.data, group=gpc.get_group(ParallelMode.MOE_DATA))
+            param_dict = get_moe_epsize_param_dict(self._model)

+            # reduce gradients for all parameters in data parallelism
+            if 1 in param_dict:
+                bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))

+            for ep_size in param_dict:
+                if ep_size != 1 and ep_size != moe_env.world_size:
+                    bucket_allreduce(param_list=param_dict[ep_size],
+                                     group=moe_env.information[ep_size].dp_group)
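Instead of walking every parameter and special-casing a `moe_param` attribute, the refactored handler makes one reduction decision per ep_size bucket. The sketch below mimics that dispatch on a hand-built param_dict for a hypothetical world size of 8; the bucket contents are placeholder strings, not real parameters.

# Illustration only: group selection per ep_size bucket, assuming world_size=8.
world_size = 8
param_dict = {1: ['dense params'], 2: ['replicated experts'], 8: ['fully sharded experts']}

for ep_size, params in param_dict.items():
    if ep_size == 1:
        print(ep_size, '-> all-reduce over the global DATA group')
    elif ep_size != world_size:
        print(ep_size, "-> all-reduce over that ep_size's expert dp_group")
    else:
        print(ep_size, '-> no reduction needed: every rank owns unique experts')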
colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py

 #!/usr/bin/env python
-from functools import total_ordering
-import torch
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
-import colossalai
+from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module

@@ -23,29 +17,5 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
     def handle_gradient(self):
         """A method running an all-reduce operation in a data parallel group.
         """
-        # bucketize and all-reduce
-        buckets = {}
-        # Pack the buckets.
-        for param in self._model.parameters():
-            if param.requires_grad and param.grad is not None:
-                tp = param.data.type()
-                if tp not in buckets:
-                    buckets[tp] = []
-                buckets[tp].append(param)
-
-        # For each bucket, all-reduce and copy all-reduced grads.
-        for tp in buckets:
-            bucket = buckets[tp]
-            grads = [param.grad.data for param in bucket]
-            coalesced = _flatten_dense_tensors(grads)
-            coalesced /= gpc.get_world_size(ParallelMode.SEQUENCE_DP)
-            dist.all_reduce(coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))
-            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                buf.copy_(synced)
+        if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1:
+            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP))
colossalai/engine/gradient_handler/utils.py (new file, 0 → 100644)

import torch.distributed as dist
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from typing import Iterable


def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
    # get communication world size
    comm_size = dist.get_world_size(group)
    # bucketize and all-reduce
    buckets = {}

    # Pack the buckets.
    for param in param_list:
        if param.requires_grad and param.grad is not None:
            tp = param.data.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(param)

    # For each bucket, all-reduce and copy all-reduced grads.
    for tp in buckets:
        bucket = buckets[tp]
        grads = [param.grad.data for param in bucket]
        coalesced = _flatten_dense_tensors(grads)
        coalesced /= comm_size
        dist.all_reduce(coalesced, group=group)
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)
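A hedged usage sketch for the new helper: it runs on a single-process gloo group so it is self-contained, and the multi-rank behavior (flatten grads per dtype, divide by the group size, all-reduce, copy back) is identical. The import path follows the file location above.

# Usage sketch: average a toy model's gradients with bucket_allreduce
# on a single-process gloo group.
import os
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.engine.gradient_handler.utils import bucket_allreduce

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

# After the call every rank holds the group-averaged gradient;
# group=None means the default (whole-world) process group.
bucket_allreduce(param_list=model.parameters(), group=None)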
colossalai/utils/moe.py (new file, 0 → 100644)

import torch.nn as nn
import torch.distributed as dist
from colossalai.core import global_context as gpc, moe_context as moe_env
from colossalai.context import ParallelMode
from .common import is_using_ddp
from typing import Dict, List


def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    """Returns a parameter dictionary whose keys are the expert parallel sizes
    of the parameters. Since parameters in data parallelism are replicated on
    every GPU, their ep_size is set to 1.

    :param model: A PyTorch nn.Module from which to build the dict
    :type model: torch.nn.Module
    """
    epsize_param_dict = dict()
    for param in model.parameters():
        if not hasattr(param, 'moe_info'):
            ep_size = 1    # set ep_size to 1 for dp parameters
        else:
            ep_size = param.moe_info.ep_size
        if ep_size not in epsize_param_dict:
            epsize_param_dict[ep_size] = []
        epsize_param_dict[ep_size].append(param)

    return epsize_param_dict


def sync_moe_model_param(model: nn.Module):
    """Make sure model parameters are consistent in the MoE parallel context.

    :param model: A PyTorch nn.Module whose parameters are checked for consistency
    :type model: torch.nn.Module
    """
    if is_using_ddp():
        param_dict = get_moe_epsize_param_dict(model)

        # synchronize the parameters whose dp_group is the whole world
        if 1 in param_dict:
            src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
            for param in param_dict[1]:
                dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))

        for ep_size in param_dict:
            # When ep_size = world_size, communication is not needed
            if ep_size != 1 and ep_size != moe_env.world_size:
                src_rank = dist.get_rank(moe_env.information[ep_size].ep_group)
                for param in param_dict[ep_size]:
                    dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
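For intuition about how get_moe_epsize_param_dict buckets parameters, the sketch below mirrors its logic on a toy model, using a SimpleNamespace as a hypothetical stand-in for the MoeInfo object that real MoE layers attach as param.moe_info.

# Illustration only: how moe_info tags drive the bucketing.
from types import SimpleNamespace
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
# Pretend the second layer holds experts sharded over 4 ranks.
for p in model[1].parameters():
    p.moe_info = SimpleNamespace(ep_size=4, dp_group=None)

buckets = {}
for param in model.parameters():
    ep_size = getattr(param, 'moe_info', SimpleNamespace(ep_size=1)).ep_size
    buckets.setdefault(ep_size, []).append(param)

print({k: len(v) for k, v in buckets.items()})    # {1: 2, 4: 2}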
tests/test_moe/short_test.py

@@ -23,13 +23,13 @@ def check_equal(A, B, atol=1e-06):
 def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     moe_set_seed(42)
     # torch.set_printoptions(precision=30)
     torch.backends.cuda.matmul.allow_tf32 = False
     local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
     torch.manual_seed(rs + local_rank)
     moe_env.reset_loss()
     tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
     # print(f"tokens:\n{tokens}")
     router = Top2Router(1)
     expert = Experts(nn.Identity, 4)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)

@@ -38,7 +38,6 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.cuda_mode = False
     old_out = layer(tokens)
     # print(f"old output:\n{old_out}")
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())

@@ -53,33 +52,27 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.cuda_mode = True
     new_out = layer(tokens)

     # print(torch.max(torch.abs(old_out - new_out)))
     if data_type == torch.float32:
         check_equal(old_out, new_out)
     else:
         check_equal(old_out, new_out, 1e-2)
     # print(f"forward functions passed")

     # print(f"new output:\n{new_out}")
     new_out.backward(grad)
     n_tk_grad = tokens.grad.data.clone()
     n_gt_grad = layer.gate.weight.grad.data.clone()

     # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
     else:
         check_equal(o_tk_grad, n_tk_grad, 1e-2)
     # print(f"tokens gradient passed")

     # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
     if data_type == torch.float32:
         check_equal(o_gt_grad, n_gt_grad, 5e-05)
     else:
         check_equal(o_gt_grad, n_gt_grad, 2e-01)
     # print(f"linear weight gradient passed")


 @pytest.mark.skip(reason="MoE refactoring has not finished yet")
 @pytest.mark.dist
 @pytest.mark.parametrize("rs", [131])
 @pytest.mark.parametrize("hidden_size", [32, 144])