Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3bb4e431
Unverified
Commit
3bb4e431
authored
Feb 16, 2026
by
Isotr0py
Committed by
GitHub
Feb 16, 2026
Browse files
[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj (#34492)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
08f8c198
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
87 additions
and
182 deletions
+87
-182
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+28
-6
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+32
-166
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+27
-10
No files found.
vllm/model_executor/layers/linear.py
View file @
3bb4e431
...
@@ -685,8 +685,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -685,8 +685,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
self
,
param
:
Parameter
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
):
if
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# Special case for GGUF
# initialize GGUF param after we know the quantize type
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
@@ -825,7 +830,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -825,7 +830,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
_load_fused_module_from_checkpoint
(
def
_load_fused_module_from_checkpoint
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
output_sizes
:
list
[
int
]
|
None
=
None
,
):
):
"""
"""
Handle special case for models where MLP layers are already
Handle special case for models where MLP layers are already
...
@@ -839,7 +847,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -839,7 +847,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset
=
0
current_shard_offset
=
0
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
output_sizes
=
output_sizes
or
self
.
output_sizes
for
i
,
output_size
in
enumerate
(
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
current_shard_offset
+=
output_size
...
@@ -864,17 +873,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -864,17 +873,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
self
,
param
:
BasevLLMParameter
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
):
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
if
isinstance
(
param
,
PerTensorScaleParameter
):
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
return
return
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
return
return
output_sizes
=
(
[
self
.
output_sizes
[
idx
]
for
idx
in
loaded_shard_id
]
if
loaded_shard_id
else
None
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
output_sizes
=
[
adjust_block_scale_shard
(
weight_block_size
,
size
,
0
)[
0
]
for
size
in
(
output_sizes
or
self
.
output_sizes
)
]
# TODO: @dsikka - move to parameter.py
# TODO: @dsikka - move to parameter.py
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
)
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
,
output_sizes
=
output_sizes
)
return
return
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
...
...
vllm/model_executor/models/qwen3_5.py
View file @
3bb4e431
...
@@ -30,36 +30,20 @@ from collections.abc import Callable, Iterable
...
@@ -30,36 +30,20 @@ from collections.abc import Callable, Iterable
import
torch
import
torch
from
einops
import
rearrange
from
einops
import
rearrange
from
torch
import
nn
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
VllmConfig
,
get_current_vllm_config
,
)
)
from
vllm.distributed
import
(
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
(
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
)
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
MergedColumnParallelLinear
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
MambaStateCopyFuncCalculator
,
...
@@ -73,11 +57,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -73,11 +57,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
default_weight_loader
,
sharded_weight_loader
,
)
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_5
import
(
from
vllm.transformers_utils.configs.qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5Config
,
...
@@ -99,7 +80,6 @@ from .interfaces import (
...
@@ -99,7 +80,6 @@ from .interfaces import (
)
)
from
.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
.qwen3_next
import
(
from
.qwen3_next
import
(
ChunkGatedDeltaRule
,
Qwen3NextAttention
,
Qwen3NextAttention
,
Qwen3NextDecoderLayer
,
Qwen3NextDecoderLayer
,
Qwen3NextGatedDeltaNet
,
Qwen3NextGatedDeltaNet
,
...
@@ -139,154 +119,31 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
...
@@ -139,154 +119,31 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
def
__init__
(
self
,
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextGatedDeltaNet
,
self
).
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_v_heads
=
config
.
linear_num_value_heads
self
.
num_k_heads
=
config
.
linear_num_key_heads
self
.
head_k_dim
=
config
.
linear_key_head_dim
self
.
head_v_dim
=
config
.
linear_value_head_dim
self
.
key_dim
=
self
.
head_k_dim
*
self
.
num_k_heads
self
.
value_dim
=
self
.
head_v_dim
*
self
.
num_v_heads
self
.
conv_kernel_size
=
config
.
linear_conv_kernel_dim
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
activation
=
config
.
hidden_act
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
config
=
config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
speculative_config
=
speculative_config
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
)
# QKV
self
.
conv_dim
=
self
.
key_dim
*
2
+
self
.
value_dim
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
conv_dim
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj_qkv
=
MergedColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_sizes
=
[
self
.
key_dim
,
self
.
key_dim
,
self
.
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkv"
,
)
self
.
in_proj_z
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_z"
,
)
self
.
in_proj_b
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_b"
,
)
self
.
in_proj_a
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_a"
,
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
value_settings
=
(
self
.
value_dim
,
0
,
False
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
[
query_key_settings
,
query_key_settings
,
value_settings
,
],
self
.
tp_size
,
self
.
tp_rank
,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_v_heads
//
self
.
tp_size
),
)
self
.
A_log
=
nn
.
Parameter
(
torch
.
empty
(
divide
(
self
.
num_v_heads
,
self
.
tp_size
),
)
)
set_weight_attrs
(
self
.
A_log
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
norm
=
RMSNormGated
(
self
.
head_v_dim
,
eps
=
self
.
layer_norm_epsilon
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
value_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
self
.
chunk_gated_delta_rule
=
ChunkGatedDeltaRule
()
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
fix_query_key_value_ordering
(
def
fix_query_key_value_ordering
(
self
,
self
,
mixed_qkv
,
mixed_qkvz
:
torch
.
Tensor
,
z
,
mixed_ba
:
torch
.
Tensor
,
b
,
a
,
):
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
"Qwen3.5 Series dont need to fix query key value ordering"
)
)
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
key_dim
,
key_dim
,
value_dim
,
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
prefix
,
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -303,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
...
@@ -303,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# ============================================================
# Part 1: Input Projection
# Part 1: Input Projection
# ============================================================
# ============================================================
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
mixed_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
z_size
=
self
.
value_dim
//
self
.
tp_size
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
b
,
_
=
self
.
in_proj_b
(
hidden_states
)
b
a
,
_
=
self
.
in_proj_b
a
(
hidden_states
)
a
,
_
=
self
.
in_proj_a
(
hidden_states
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
b
=
b
.
contiguous
()
b
=
b
.
contiguous
()
a
=
a
.
contiguous
()
a
=
a
.
contiguous
()
...
@@ -506,11 +365,18 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -506,11 +365,18 @@ class Qwen3_5Model(Qwen3NextModel):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
# self attention
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
# mlp
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
# GDN
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
...
...
vllm/model_executor/models/qwen3_next.py
View file @
3bb4e431
...
@@ -44,6 +44,7 @@ from vllm.model_executor.layers.layernorm import (
...
@@ -44,6 +44,7 @@ from vllm.model_executor.layers.layernorm import (
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelLinear
,
...
@@ -406,19 +407,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -406,19 +407,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
# projection of the input hidden states
# projection of the input hidden states
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
# we need to create qkvz_proj adaptively here.
self
.
in_proj_qkvz
=
ColumnParallelLinear
(
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
input
_size
=
self
.
hidden_size
,
hidden
_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_qkvz
,
key_dim
=
self
.
key_dim
,
bias
=
False
,
value_dim
=
self
.
value_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
)
# ba_proj doesn't support blockwise fp8 quantization.
# ba_proj doesn't support blockwise fp8 quantization.
self
.
in_proj_ba
=
ColumnParallelLinear
(
self
.
in_proj_ba
=
Merged
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_ba
,
output_size
s
=
[
self
.
num_v_heads
]
*
2
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
...
@@ -484,10 +485,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -484,10 +485,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
sum
((
key_dim
,
key_dim
,
value_dim
)),
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
def
fix_query_key_value_ordering
(
def
fix_query_key_value_ordering
(
self
,
self
,
mixed_qkvz
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_ba
,
mixed_ba
:
torch
.
Tensor
,
):
):
"""
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment