Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d65283e
Unverified
Commit
1d65283e
authored
Feb 17, 2026
by
Jiangyun Zhu
Committed by
GitHub
Feb 17, 2026
Browse files
Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj" (#34683)
parent
c464b573
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
182 additions
and
87 deletions
+182
-87
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+6
-28
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+166
-32
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+10
-27
No files found.
vllm/model_executor/layers/linear.py
View file @
1d65283e
...
@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
self
,
param
:
Parameter
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
loaded_shard_id
:
int
|
None
=
None
,
):
):
if
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# Special case for GGUF
# initialize GGUF param after we know the quantize type
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
_load_fused_module_from_checkpoint
(
def
_load_fused_module_from_checkpoint
(
self
,
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
output_sizes
:
list
[
int
]
|
None
=
None
,
):
):
"""
"""
Handle special case for models where MLP layers are already
Handle special case for models where MLP layers are already
...
@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset
=
0
current_shard_offset
=
0
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
output_sizes
=
output_sizes
or
self
.
output_sizes
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
for
i
,
output_size
in
enumerate
(
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
current_shard_offset
+=
output_size
...
@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
self
,
param
:
BasevLLMParameter
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
loaded_shard_id
:
int
|
None
=
None
,
):
):
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
if
loaded_shard_id
is
None
:
if
isinstance
(
param
,
PerTensorScaleParameter
):
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
return
return
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
return
return
output_sizes
=
(
[
self
.
output_sizes
[
idx
]
for
idx
in
loaded_shard_id
]
if
loaded_shard_id
else
None
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
output_sizes
=
[
adjust_block_scale_shard
(
weight_block_size
,
size
,
0
)[
0
]
for
size
in
(
output_sizes
or
self
.
output_sizes
)
]
# TODO: @dsikka - move to parameter.py
# TODO: @dsikka - move to parameter.py
self
.
_load_fused_module_from_checkpoint
(
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
)
param
,
loaded_weight
,
output_sizes
=
output_sizes
)
return
return
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
...
...
vllm/model_executor/models/qwen3_5.py
View file @
1d65283e
...
@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
...
@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
import
torch
import
torch
from
einops
import
rearrange
from
einops
import
rearrange
from
torch
import
nn
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
VllmConfig
,
get_current_vllm_config
,
)
)
from
vllm.distributed
import
(
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
(
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
)
from
vllm.model_executor.layers.linear
import
MergedColumnParallelLinear
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
MambaStateCopyFuncCalculator
,
...
@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
default_weight_loader
,
sharded_weight_loader
,
)
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_5
import
(
from
vllm.transformers_utils.configs.qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5Config
,
...
@@ -80,6 +99,7 @@ from .interfaces import (
...
@@ -80,6 +99,7 @@ from .interfaces import (
)
)
from
.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
.qwen3_next
import
(
from
.qwen3_next
import
(
ChunkGatedDeltaRule
,
Qwen3NextAttention
,
Qwen3NextAttention
,
Qwen3NextDecoderLayer
,
Qwen3NextDecoderLayer
,
Qwen3NextGatedDeltaNet
,
Qwen3NextGatedDeltaNet
,
...
@@ -119,29 +139,152 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
...
@@ -119,29 +139,152 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
def
fix_query_key_value_ordering
(
def
__init__
(
self
,
self
,
mixed_qkvz
:
torch
.
Tensor
,
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
,
mixed_ba
:
torch
.
Tensor
,
model_config
:
ModelConfig
|
None
=
None
,
):
cache_config
:
CacheConfig
|
None
=
None
,
raise
NotImplementedError
(
quant_config
:
QuantizationConfig
|
None
=
None
,
"Qwen3.5 Series dont need to fix query key value ordering"
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextGatedDeltaNet
,
self
).
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_v_heads
=
config
.
linear_num_value_heads
self
.
num_k_heads
=
config
.
linear_num_key_heads
self
.
head_k_dim
=
config
.
linear_key_head_dim
self
.
head_v_dim
=
config
.
linear_value_head_dim
self
.
key_dim
=
self
.
head_k_dim
*
self
.
num_k_heads
self
.
value_dim
=
self
.
head_v_dim
*
self
.
num_v_heads
self
.
conv_kernel_size
=
config
.
linear_conv_kernel_dim
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
activation
=
config
.
hidden_act
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
config
=
config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
speculative_config
=
speculative_config
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
)
)
def
create_qkvz_proj
(
# QKV
self
,
self
.
conv_dim
=
self
.
key_dim
*
2
+
self
.
value_dim
hidden_size
:
int
,
self
.
conv1d
=
ColumnParallelLinear
(
key_dim
:
int
,
input_size
=
self
.
conv_kernel_size
,
value_dim
:
int
,
output_size
=
self
.
conv_dim
,
quant_config
:
QuantizationConfig
|
None
,
bias
=
False
,
prefix
:
str
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
->
MergedColumnParallelLinear
:
)
return
MergedColumnParallelLinear
(
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
input_size
=
hidden_size
,
output_sizes
=
[
key_dim
,
key_dim
,
value_dim
,
value_dim
],
self
.
in_proj_qkv
=
MergedColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_sizes
=
[
self
.
key_dim
,
self
.
key_dim
,
self
.
value_dim
],
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
prefix
=
f
"
{
prefix
}
.in_proj_qkv"
,
)
self
.
in_proj_z
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_z"
,
)
self
.
in_proj_b
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_b"
,
)
self
.
in_proj_a
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_a"
,
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
value_settings
=
(
self
.
value_dim
,
0
,
False
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
[
query_key_settings
,
query_key_settings
,
value_settings
,
],
self
.
tp_size
,
self
.
tp_rank
,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_v_heads
//
self
.
tp_size
),
)
self
.
A_log
=
nn
.
Parameter
(
torch
.
empty
(
divide
(
self
.
num_v_heads
,
self
.
tp_size
),
)
)
set_weight_attrs
(
self
.
A_log
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
norm
=
RMSNormGated
(
self
.
head_v_dim
,
eps
=
self
.
layer_norm_epsilon
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
value_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
self
.
chunk_gated_delta_rule
=
ChunkGatedDeltaRule
()
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
fix_query_key_value_ordering
(
self
,
mixed_qkv
,
z
,
b
,
a
,
):
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
)
)
def
forward
(
def
forward
(
...
@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
...
@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# ============================================================
# Part 1: Input Projection
# Part 1: Input Projection
# ============================================================
# ============================================================
mixed_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
z_size
=
self
.
value_dim
//
self
.
tp_size
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
b
a
,
_
=
self
.
in_proj_b
a
(
hidden_states
)
b
,
_
=
self
.
in_proj_b
(
hidden_states
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
a
,
_
=
self
.
in_proj_a
(
hidden_states
)
b
=
b
.
contiguous
()
b
=
b
.
contiguous
()
a
=
a
.
contiguous
()
a
=
a
.
contiguous
()
...
@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
# self attention
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
# mlp
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
# GDN
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
...
...
vllm/model_executor/models/qwen3_next.py
View file @
1d65283e
...
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
...
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelLinear
,
...
@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
# projection of the input hidden states
# projection of the input hidden states
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
# we need to create qkvz_proj adaptively here.
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
self
.
in_proj_qkvz
=
ColumnParallelLinear
(
hidden
_size
=
self
.
hidden_size
,
input
_size
=
self
.
hidden_size
,
key_dim
=
self
.
key_dim
,
output_size
=
self
.
projection_size_qkvz
,
value_dim
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
)
# ba_proj doesn't support blockwise fp8 quantization.
# ba_proj doesn't support blockwise fp8 quantization.
self
.
in_proj_ba
=
Merged
ColumnParallelLinear
(
self
.
in_proj_ba
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
output_size
s
=
[
self
.
num_v_heads
]
*
2
,
output_size
=
self
.
projection_size_ba
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
...
@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
sum
((
key_dim
,
key_dim
,
value_dim
)),
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
def
fix_query_key_value_ordering
(
def
fix_query_key_value_ordering
(
self
,
self
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_qkvz
,
mixed_ba
:
torch
.
Tensor
,
mixed_ba
,
):
):
"""
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment