Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2397 additions
and
29 deletions
+2397
-29
vllm/lora/models.py
vllm/lora/models.py
+1
-1
vllm/lora/peft_helper.py
vllm/lora/peft_helper.py
+1
-1
vllm/lora/punica_wrapper/punica_xpu.py
vllm/lora/punica_wrapper/punica_xpu.py
+10
-3
vllm/lora/utils.py
vllm/lora/utils.py
+7
-7
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+1
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+0
-7
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+5
-9
vllm/model_executor/layers/fla/__init__.py
vllm/model_executor/layers/fla/__init__.py
+8
-0
vllm/model_executor/layers/fla/ops/__init__.py
vllm/model_executor/layers/fla/ops/__init__.py
+17
-0
vllm/model_executor/layers/fla/ops/chunk.py
vllm/model_executor/layers/fla/ops/chunk.py
+225
-0
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
+290
-0
vllm/model_executor/layers/fla/ops/chunk_o.py
vllm/model_executor/layers/fla/ops/chunk_o.py
+177
-0
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
+140
-0
vllm/model_executor/layers/fla/ops/cumsum.py
vllm/model_executor/layers/fla/ops/cumsum.py
+226
-0
vllm/model_executor/layers/fla/ops/fused_recurrent.py
vllm/model_executor/layers/fla/ops/fused_recurrent.py
+366
-0
vllm/model_executor/layers/fla/ops/index.py
vllm/model_executor/layers/fla/ops/index.py
+39
-0
vllm/model_executor/layers/fla/ops/l2norm.py
vllm/model_executor/layers/fla/ops/l2norm.py
+143
-0
vllm/model_executor/layers/fla/ops/layernorm_guard.py
vllm/model_executor/layers/fla/ops/layernorm_guard.py
+337
-0
vllm/model_executor/layers/fla/ops/op.py
vllm/model_executor/layers/fla/ops/op.py
+39
-0
vllm/model_executor/layers/fla/ops/solve_tril.py
vllm/model_executor/layers/fla/ops/solve_tril.py
+365
-0
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
vllm/lora/models.py
View file @
38d80967
...
...
@@ -16,7 +16,7 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
from
vllm.adapter_commons.utils
import
(
add_adapter
,
deactivate_adapter
,
get_adapter
,
list_adapters
,
remove_adapter
,
set_adapter_mapping
)
from
vllm.config
import
LoRAConfig
from
vllm.config
.lora
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
LoRAMapping
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
...
...
vllm/lora/peft_helper.py
View file @
38d80967
...
...
@@ -9,7 +9,7 @@ import os
from
dataclasses
import
MISSING
,
dataclass
,
field
,
fields
from
typing
import
Literal
,
Optional
,
Union
from
vllm.config
import
LoRAConfig
from
vllm.config
.lora
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
...
vllm/lora/punica_wrapper/punica_xpu.py
View file @
38d80967
...
...
@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
add_inputs
=
True
,
**
kwargs
)
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to padded sampler indices.
"""
return
self
.
_sampler_indices_padded
[:]
def
add_lora_logits
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
self
.
sampler_indices
,
scale
)
sampler_indices
=
torch
.
narrow
(
self
.
_sampler_indices
,
0
,
0
,
x
.
size
(
0
))
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
lora_b_stacked
,
y
,
self
.
sampler_indices
,
sampler_indices
,
add_inputs
=
True
)
return
y
.
view_as
(
y_org
)
vllm/lora/utils.py
View file @
38d80967
...
...
@@ -11,23 +11,23 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.config
.lora
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.lora.fully_sharded_layers
import
(
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLoRA
,
QKVParallelLinearWithShardedLoRA
,
RowParallelLinearWithShardedLoRA
)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithLoRA
,
MergedQKVParallelLinearWithShardedLoRA
,
QKVParallelLinearWithLoRA
,
QKVParallelLinearWithShardedLoRA
,
ReplicatedLinearWithLoRA
,
RowParallelLinearWithLoRA
,
RowParallelLinearWithShardedLoRA
,
VocabParallelEmbeddingWithLoRA
)
from
vllm.model_executor.layers.linear
import
LinearBase
...
...
@@ -239,7 +239,7 @@ def get_adapter_absolute_path(lora_path: str) -> str:
except
(
HfHubHTTPError
,
RepositoryNotFoundError
,
EntryNotFoundError
,
HFValidationError
):
# Handle errors that may occur during the download
# Return original path instead
instead
of throwing error here
# Return original path instead of throwing error here
logger
.
exception
(
"Error downloading the HuggingFace model"
)
return
lora_path
...
...
vllm/lora/worker_manager.py
View file @
38d80967
...
...
@@ -11,7 +11,7 @@ from vllm.adapter_commons.utils import (add_adapter_worker,
list_adapters_worker
,
set_active_adapters_worker
)
from
vllm.adapter_commons.worker_manager
import
AbstractWorkerManager
from
vllm.config
import
LoRAConfig
from
vllm.config
.lora
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.lora.models
import
(
LoRAModel
,
LoRAModelManager
,
LRUCacheLoRAModelManager
,
create_lora_manager
)
...
...
vllm/model_executor/custom_op.py
View file @
38d80967
...
...
@@ -73,11 +73,6 @@ class CustomOp(nn.Module):
# NOTE(woosuk): This is a placeholder for future extensions.
return
self
.
forward_native
(
*
args
,
**
kwargs
)
def
forward_neuron
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return
self
.
forward_native
(
*
args
,
**
kwargs
)
def
forward_oot
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
...
...
@@ -105,8 +100,6 @@ class CustomOp(nn.Module):
return
self
.
forward_tpu
elif
current_platform
.
is_xpu
():
return
self
.
forward_xpu
elif
current_platform
.
is_neuron
():
return
self
.
forward_neuron
elif
current_platform
.
is_out_of_tree
():
return
self
.
forward_oot
else
:
...
...
vllm/model_executor/layers/activation.py
View file @
38d80967
...
...
@@ -95,13 +95,6 @@ class SiluAndMul(CustomOp):
self
.
op
(
out
,
x
)
return
out
def
forward_neuron
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
x_reshaped
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
s
=
x_reshaped
[:,
:
d
]
*
F
.
sigmoid
(
x_reshaped
[:,
:
d
])
result
=
s
*
x_reshaped
[:,
d
:]
return
result
.
view
(
*
x
.
shape
[:
-
1
],
d
)
@
CustomOp
.
register
(
"mul_and_silu"
)
class
MulAndSilu
(
CustomOp
):
...
...
@@ -362,7 +355,7 @@ class ReLUSquaredActivation(CustomOp):
return
torch
.
square
(
F
.
relu
(
x
))
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
#TODO : implement cuda ke
n
rels
#TODO : implement cuda ker
n
els
return
self
.
forward_native
(
x
)
...
...
@@ -461,7 +454,7 @@ class XIELU(CustomOp):
)
return
result
.
view
(
original_shape
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
_native
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
_xielu_cuda_obj
is
not
None
and
input
.
is_cuda
:
if
not
torch
.
_dynamo
.
is_compiling
():
return
self
.
_xielu_cuda_fn
(
input
)
...
...
@@ -471,6 +464,9 @@ class XIELU(CustomOp):
)
return
self
.
_xielu_python
(
input
)
def
forward_cuda
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
forward_native
(
input
)
class
ScaledActivation
(
nn
.
Module
):
"""An activation function with post-scale parameters.
...
...
vllm/model_executor/layers/fla/__init__.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
vllm/model_executor/layers/fla/ops/__init__.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from
.chunk
import
chunk_gated_delta_rule
from
.fused_recurrent
import
fused_recurrent_gated_delta_rule
from
.layernorm_guard
import
RMSNormGated
__all__
=
[
"RMSNormGated"
,
"chunk_gated_delta_rule"
,
"fused_recurrent_gated_delta_rule"
,
]
vllm/model_executor/layers/fla/ops/chunk.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
warnings
from
typing
import
Optional
import
torch
from
einops
import
rearrange
from
.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
from
.chunk_o
import
chunk_fwd_o
from
.chunk_scaled_dot_kkt
import
chunk_scaled_dot_kkt_fwd
from
.cumsum
import
chunk_local_cumsum
from
.l2norm
import
l2norm_fwd
from
.solve_tril
import
solve_tril
from
.utils
import
SUPPRESS_LEVEL
,
input_guard
from
.wy_fast
import
recompute_w_u_fwd
def
chunk_gated_delta_rule_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
):
g
=
chunk_local_cumsum
(
g
,
chunk_size
=
64
,
cu_seqlens
=
cu_seqlens
)
# obtain WY representation. u is actually the new v.
A
=
chunk_scaled_dot_kkt_fwd
(
k
=
k
,
beta
=
beta
,
g_cumsum
=
g
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
torch
.
float32
)
A
=
solve_tril
(
A
=
A
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
k
.
dtype
)
w
,
u
=
recompute_w_u_fwd
(
k
=
k
,
v
=
v
,
beta
=
beta
,
A
=
A
,
g_cumsum
=
g
,
cu_seqlens
=
cu_seqlens
,
)
h
,
v_new
,
final_state
=
chunk_gated_delta_rule_fwd_h
(
k
=
k
,
w
=
w
,
u
=
u
,
g
=
g
,
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
)
o
=
chunk_fwd_o
(
q
=
q
,
k
=
k
,
v
=
v_new
,
h
=
h
,
g
=
g
,
scale
=
scale
,
cu_seqlens
=
cu_seqlens
,
)
if
SUPPRESS_LEVEL
<
3
:
return
g
,
o
,
A
,
final_state
,
None
,
None
,
None
elif
SUPPRESS_LEVEL
>=
3
:
return
g
,
o
,
A
,
final_state
,
w
,
h
,
v_new
class
ChunkGatedDeltaRuleFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
input_guard
@
torch
.
amp
.
custom_fwd
(
device_type
=
'cuda'
)
def
forward
(
ctx
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
):
if
use_qk_l2norm_in_kernel
:
q
=
l2norm_fwd
(
q
)
k
=
l2norm_fwd
(
k
)
g
,
o
,
A
,
final_state
,
w
,
h
,
v_new
=
chunk_gated_delta_rule_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
beta
=
beta
,
scale
=
scale
,
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
)
ctx
.
scale
=
scale
ctx
.
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
return
o
.
to
(
q
.
dtype
),
final_state
@
torch
.
compiler
.
disable
def
chunk_gated_delta_rule
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
output_final_state
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
False
):
r
"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
beta (torch.Tensor):
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
"""
assert
q
.
dtype
==
k
.
dtype
==
v
.
dtype
assert
q
.
dtype
!=
torch
.
float32
,
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert
len
(
beta
.
shape
)
==
3
,
"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if
head_first
:
raise
DeprecationWarning
(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
,
stacklevel
=
2
)
q
,
k
,
v
,
beta
,
g
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t ... -> b t h ...'
),
(
q
,
k
,
v
,
beta
,
g
))
if
not
head_first
and
q
.
shape
[
1
]
<
q
.
shape
[
2
]:
warnings
.
warn
(
f
"Input tensor shape suggests potential format mismatch: seq_len (
{
q
.
shape
[
1
]
}
) < num_heads (
{
q
.
shape
[
2
]
}
). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
,
stacklevel
=
2
)
if
cu_seqlens
is
not
None
:
if
q
.
shape
[
0
]
!=
1
:
raise
ValueError
(
f
"The batch size is expected to be 1 rather than
{
q
.
shape
[
0
]
}
when using `cu_seqlens`."
f
"Please flatten variable-length inputs before processing."
)
if
initial_state
is
not
None
and
initial_state
.
shape
[
0
]
!=
len
(
cu_seqlens
)
-
1
:
raise
ValueError
(
f
"The number of initial states is expected to be equal to the number of input sequences, "
f
"i.e.,
{
len
(
cu_seqlens
)
-
1
}
rather than
{
initial_state
.
shape
[
0
]
}
."
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
o
,
final_state
=
ChunkGatedDeltaRuleFunction
.
apply
(
q
,
k
,
v
,
g
,
beta
,
scale
,
initial_state
,
output_final_state
,
cu_seqlens
,
use_qk_l2norm_in_kernel
)
if
head_first
:
o
=
rearrange
(
o
,
'b t h ... -> b h t ...'
)
return
o
,
final_state
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
,
prepare_chunk_offsets
from
.op
import
exp
from
.utils
import
is_nvidia_hopper
,
use_cuda_graph
NUM_WARPS
=
[
2
,
4
]
if
is_nvidia_hopper
else
[
2
,
4
,
8
,
16
]
@
triton
.
heuristics
({
'USE_G'
:
lambda
args
:
args
[
'g'
]
is
not
None
,
'USE_INITIAL_STATE'
:
lambda
args
:
args
[
'h0'
]
is
not
None
,
'STORE_FINAL_STATE'
:
lambda
args
:
args
[
'ht'
]
is
not
None
,
'SAVE_NEW_VALUE'
:
lambda
args
:
args
[
'v_new'
]
is
not
None
,
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
,
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BV'
:
BV
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
2
,
4
]
for
num_stages
in
[
2
,
3
,
4
]
for
BV
in
[
32
,
64
]
],
key
=
[
'H'
,
'K'
,
'V'
,
'BT'
,
'USE_G'
],
use_cuda_graph
=
use_cuda_graph
,
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_gated_delta_rule_fwd_kernel_h_blockdim64
(
k
,
v
,
w
,
v_new
,
g
,
h
,
h0
,
ht
,
cu_seqlens
,
chunk_offsets
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
USE_INITIAL_STATE
:
tl
.
constexpr
,
STORE_FINAL_STATE
:
tl
.
constexpr
,
SAVE_NEW_VALUE
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_n
,
i_h
=
i_nh
//
H
,
i_nh
%
H
if
IS_VARLEN
:
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
NT
=
tl
.
cdiv
(
T
,
BT
)
boh
=
tl
.
load
(
chunk_offsets
+
i_n
).
to
(
tl
.
int32
)
else
:
bos
,
eos
=
i_n
*
T
,
i_n
*
T
+
T
NT
=
tl
.
cdiv
(
T
,
BT
)
boh
=
i_n
*
NT
# [BK, BV]
b_h1
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
if
K
>
64
:
b_h2
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
if
K
>
128
:
b_h3
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
if
K
>
192
:
b_h4
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
# calculate offset
h
+=
(
boh
*
H
+
i_h
)
*
K
*
V
v
+=
(
bos
*
H
+
i_h
)
*
V
k
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
w
+=
(
bos
*
H
+
i_h
)
*
K
if
SAVE_NEW_VALUE
:
v_new
+=
(
bos
*
H
+
i_h
)
*
V
stride_v
=
H
*
V
stride_h
=
H
*
K
*
V
stride_k
=
Hg
*
K
stride_w
=
H
*
K
if
USE_INITIAL_STATE
:
h0
=
h0
+
i_nh
*
K
*
V
if
STORE_FINAL_STATE
:
ht
=
ht
+
i_nh
*
K
*
V
# load initial state
if
USE_INITIAL_STATE
:
p_h0_1
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h1
+=
tl
.
load
(
p_h0_1
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
K
>
64
:
p_h0_2
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
64
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h2
+=
tl
.
load
(
p_h0_2
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
K
>
128
:
p_h0_3
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
128
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h3
+=
tl
.
load
(
p_h0_3
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
K
>
192
:
p_h0_4
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
192
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h4
+=
tl
.
load
(
p_h0_4
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
# main recurrence
for
i_t
in
range
(
NT
):
p_h1
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h1
,
b_h1
.
to
(
p_h1
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
64
:
p_h2
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
64
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h2
,
b_h2
.
to
(
p_h2
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
128
:
p_h3
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
128
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h3
,
b_h3
.
to
(
p_h3
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
192
:
p_h4
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
192
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h4
,
b_h4
.
to
(
p_h4
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
p_v
=
tl
.
make_block_ptr
(
v
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
p_v_new
=
tl
.
make_block_ptr
(
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
if
SAVE_NEW_VALUE
else
None
b_v_new
=
tl
.
zeros
([
BT
,
BV
],
dtype
=
tl
.
float32
)
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h1
.
to
(
b_w
.
dtype
))
if
K
>
64
:
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
64
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h2
.
to
(
b_w
.
dtype
))
if
K
>
128
:
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
128
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h3
.
to
(
b_w
.
dtype
))
if
K
>
192
:
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
192
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h4
.
to
(
b_w
.
dtype
))
b_v_new
=
-
b_v_new
+
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
if
SAVE_NEW_VALUE
:
p_v_new
=
tl
.
make_block_ptr
(
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
tl
.
store
(
p_v_new
,
b_v_new
.
to
(
p_v_new
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
USE_G
:
m_t
=
(
i_t
*
BT
+
tl
.
arange
(
0
,
BT
))
<
T
last_idx
=
min
((
i_t
+
1
)
*
BT
,
T
)
-
1
b_g_last
=
tl
.
load
(
g
+
bos
*
H
+
last_idx
*
H
+
i_h
)
p_g
=
tl
.
make_block_ptr
(
g
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
))
b_v_new
=
b_v_new
*
tl
.
where
(
m_t
,
exp
(
b_g_last
-
b_g
),
0
)[:,
None
]
b_g_last
=
exp
(
b_g_last
)
b_h1
=
b_h1
*
b_g_last
if
K
>
64
:
b_h2
=
b_h2
*
b_g_last
if
K
>
128
:
b_h3
=
b_h3
*
b_g_last
if
K
>
192
:
b_h4
=
b_h4
*
b_g_last
b_v_new
=
b_v_new
.
to
(
k
.
dtype
.
element_ty
)
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
0
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h1
+=
tl
.
dot
(
b_k
,
b_v_new
)
if
K
>
64
:
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
64
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h2
+=
tl
.
dot
(
b_k
,
b_v_new
)
if
K
>
128
:
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
128
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h3
+=
tl
.
dot
(
b_k
,
b_v_new
)
if
K
>
192
:
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
192
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h4
+=
tl
.
dot
(
b_k
,
b_v_new
)
# epilogue
if
STORE_FINAL_STATE
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h1
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
64
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
64
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h2
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
128
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
128
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h3
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
192
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
192
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h4
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_gated_delta_rule_fwd_h
(
k
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
u
:
torch
.
Tensor
,
g
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
output_final_state
:
bool
=
False
,
chunk_size
:
int
=
64
,
# SY: remove this argument and force chunk size 64?
save_new_value
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
Hg
,
K
,
V
=
*
k
.
shape
,
u
.
shape
[
-
1
]
H
=
u
.
shape
[
-
2
]
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
chunk_size
)
if
cu_seqlens
is
not
None
else
None
# N: the actual number of sequences in the batch with either equal or variable lengths
if
cu_seqlens
is
None
:
N
,
NT
,
chunk_offsets
=
B
,
triton
.
cdiv
(
T
,
BT
),
None
else
:
N
,
NT
,
chunk_offsets
=
len
(
cu_seqlens
)
-
1
,
len
(
chunk_indices
),
prepare_chunk_offsets
(
cu_seqlens
,
BT
)
assert
K
<=
256
,
"current kernel does not support head dimension larger than 256."
h
=
k
.
new_empty
(
B
,
NT
,
H
,
K
,
V
)
final_state
=
k
.
new_empty
(
N
,
H
,
K
,
V
,
dtype
=
torch
.
float32
)
if
output_final_state
else
None
v_new
=
torch
.
empty_like
(
u
)
if
save_new_value
else
None
def
grid
(
meta
):
return
(
triton
.
cdiv
(
V
,
meta
[
'BV'
]),
N
*
H
)
chunk_gated_delta_rule_fwd_kernel_h_blockdim64
[
grid
](
k
=
k
,
v
=
u
,
w
=
w
,
v_new
=
v_new
,
g
=
g
,
h
=
h
,
h0
=
initial_state
,
ht
=
final_state
,
cu_seqlens
=
cu_seqlens
,
chunk_offsets
=
chunk_offsets
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
V
=
V
,
BT
=
BT
)
return
h
,
v_new
,
final_state
vllm/model_executor/layers/fla/ops/chunk_o.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.op
import
exp
from
.utils
import
FLA_GDN_FIX_BT
,
check_shared_mem
,
is_nvidia_hopper
BKV_LIST
=
[
64
,
128
]
if
check_shared_mem
()
else
[
32
,
64
]
NUM_WARPS
=
[
2
,
4
]
if
is_nvidia_hopper
else
[
2
,
4
,
8
]
@
triton
.
heuristics
({
'USE_G'
:
lambda
args
:
args
[
'g'
]
is
not
None
,
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BK'
:
BK
,
'BV'
:
BV
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
BK
in
BKV_LIST
for
BV
in
BKV_LIST
for
num_warps
in
NUM_WARPS
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
'H'
,
'K'
,
'V'
,
'BT'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_fwd_kernel_o
(
q
,
k
,
v
,
h
,
g
,
o
,
cu_seqlens
,
chunk_indices
,
scale
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_v
,
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_tg
=
i_t
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
NT
=
tl
.
cdiv
(
T
,
BT
)
else
:
NT
=
tl
.
cdiv
(
T
,
BT
)
i_tg
=
i_b
*
NT
+
i_t
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
# offset calculation
q
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
k
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
v
+=
(
bos
*
H
+
i_h
)
*
V
o
+=
(
bos
*
H
+
i_h
)
*
V
h
+=
(
i_tg
*
H
+
i_h
).
to
(
tl
.
int64
)
*
K
*
V
b_o
=
tl
.
zeros
([
BT
,
BV
],
dtype
=
tl
.
float32
)
b_A
=
tl
.
zeros
([
BT
,
BT
],
dtype
=
tl
.
float32
)
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_q
=
tl
.
make_block_ptr
(
q
,
(
T
,
K
),
(
Hg
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
))
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
Hg
*
K
),
(
i_k
*
BK
,
i_t
*
BT
),
(
BK
,
BT
),
(
0
,
1
))
p_h
=
tl
.
make_block_ptr
(
h
,
(
K
,
V
),
(
V
,
1
),
(
i_k
*
BK
,
i_v
*
BV
),
(
BK
,
BV
),
(
1
,
0
))
# [BT, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
# [BK, BT]
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
# [BK, BV]
b_h
=
tl
.
load
(
p_h
,
boundary_check
=
(
0
,
1
))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o
+=
tl
.
dot
(
b_q
,
b_h
)
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A
+=
tl
.
dot
(
b_q
,
b_k
)
if
USE_G
:
g
+=
bos
*
H
+
i_h
p_g
=
tl
.
make_block_ptr
(
g
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
))
b_o
=
b_o
*
exp
(
b_g
)[:,
None
]
b_A
=
b_A
*
exp
(
b_g
[:,
None
]
-
b_g
[
None
,
:])
o_t
=
i_t
*
BT
+
tl
.
arange
(
0
,
BT
)
m_t
=
o_t
<
T
m_A
=
(
o_t
[:,
None
]
>=
o_t
[
None
,
:])
&
(
m_t
[:,
None
]
&
m_t
)
b_A
=
tl
.
where
(
m_A
,
b_A
,
0
)
p_v
=
tl
.
make_block_ptr
(
v
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
p_o
=
tl
.
make_block_ptr
(
o
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
b_v
=
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o
=
b_o
*
scale
+
tl
.
dot
(
b_A
.
to
(
b_v
.
dtype
),
b_v
)
*
scale
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_fwd_o
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
g
:
Optional
[
torch
.
Tensor
]
=
None
,
# cumsum of log decay
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
chunk_size
:
int
=
64
)
->
torch
.
Tensor
:
B
,
T
,
Hg
,
K
,
V
=
*
q
.
shape
,
v
.
shape
[
-
1
]
H
=
v
.
shape
[
-
2
]
if
FLA_GDN_FIX_BT
:
BT
=
64
else
:
BT
=
min
(
chunk_size
,
max
(
16
,
triton
.
next_power_of_2
(
T
)))
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
o
=
torch
.
empty_like
(
v
)
def
grid
(
meta
):
return
(
triton
.
cdiv
(
V
,
meta
[
'BV'
]),
NT
,
B
*
H
)
chunk_fwd_kernel_o
[
grid
](
q
,
k
,
v
,
h
,
g
,
o
,
cu_seqlens
,
chunk_indices
,
scale
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
V
=
V
,
BT
=
BT
,
)
return
o
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.op
import
exp
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
,
'USE_G'
:
lambda
args
:
args
[
'g_cumsum'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BK'
:
BK
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
BK
in
[
32
,
64
,
128
]
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
'H'
,
'K'
,
'BT'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_scaled_dot_kkt_fwd_kernel
(
k
,
beta
,
g_cumsum
,
A
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
o_t
=
i_t
*
BT
+
tl
.
arange
(
0
,
BT
)
m_t
=
o_t
<
T
p_beta
=
tl
.
make_block_ptr
(
beta
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_beta
=
tl
.
load
(
p_beta
,
boundary_check
=
(
0
,
))
b_A
=
tl
.
zeros
([
BT
,
BT
],
dtype
=
tl
.
float32
)
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
,
(
T
,
K
),
(
Hg
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_kb
=
b_k
*
b_beta
[:,
None
]
b_A
+=
tl
.
dot
(
b_kb
.
to
(
b_k
.
dtype
),
tl
.
trans
(
b_k
))
if
USE_G
:
p_g
=
tl
.
make_block_ptr
(
g_cumsum
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
))
b_g_diff
=
b_g
[:,
None
]
-
b_g
[
None
,
:]
b_A
=
b_A
*
exp
(
b_g_diff
)
m_A
=
(
o_t
[:,
None
]
>
o_t
[
None
,
:])
&
(
m_t
[:,
None
]
&
m_t
)
b_A
=
tl
.
where
(
m_A
,
b_A
,
0
)
p_A
=
tl
.
make_block_ptr
(
A
+
(
bos
*
H
+
i_h
)
*
BT
,
(
T
,
BT
),
(
BT
*
H
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BT
),
(
1
,
0
))
tl
.
store
(
p_A
,
b_A
.
to
(
p_A
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_scaled_dot_kkt_fwd
(
k
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
g_cumsum
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
chunk_size
:
int
=
64
,
output_dtype
:
torch
.
dtype
=
torch
.
float32
)
->
torch
.
Tensor
:
r
"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g_cumsum (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`.
Default: None
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B
,
T
,
Hg
,
K
=
k
.
shape
H
=
beta
.
shape
[
-
1
]
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
A
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
k
.
device
,
dtype
=
output_dtype
)
chunk_scaled_dot_kkt_fwd_kernel
[(
NT
,
B
*
H
)](
k
=
k
,
beta
=
beta
,
g_cumsum
=
g_cumsum
,
A
=
A
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
BT
=
BT
,
)
return
A
vllm/model_executor/layers/fla/ops/cumsum.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
warnings
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.utils
import
check_shared_mem
,
input_guard
BS_LIST
=
[
32
,
64
]
if
check_shared_mem
()
else
[
16
,
32
]
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]
],
key
=
[
'B'
,
'H'
,
'BT'
,
'IS_VARLEN'
,
'REVERSE'
])
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_local_cumsum_scalar_kernel
(
s
,
o
,
cu_seqlens
,
chunk_indices
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
REVERSE
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
HEAD_FIRST
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
if
HEAD_FIRST
:
p_s
=
tl
.
make_block_ptr
(
s
+
bos
*
H
+
i_h
*
T
,
(
T
,
),
(
1
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
p_o
=
tl
.
make_block_ptr
(
o
+
bos
*
H
+
i_h
*
T
,
(
T
,
),
(
1
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
else
:
p_s
=
tl
.
make_block_ptr
(
s
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
p_o
=
tl
.
make_block_ptr
(
o
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
# [BT]
b_s
=
tl
.
load
(
p_s
,
boundary_check
=
(
0
,
)).
to
(
tl
.
float32
)
b_o
=
tl
.
cumsum
(
b_s
,
axis
=
0
)
if
REVERSE
:
b_z
=
tl
.
sum
(
b_s
,
axis
=
0
)
b_o
=
-
b_o
+
b_z
[
None
]
+
b_s
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
))
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BS'
:
BS
},
num_warps
=
num_warps
)
for
BS
in
BS_LIST
for
num_warps
in
[
2
,
4
,
8
]
],
key
=
[
'B'
,
'H'
,
'S'
,
'BT'
,
'IS_VARLEN'
,
'REVERSE'
])
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_local_cumsum_vector_kernel
(
s
,
o
,
cu_seqlens
,
chunk_indices
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
REVERSE
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
HEAD_FIRST
:
tl
.
constexpr
,
):
i_s
,
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
o_i
=
tl
.
arange
(
0
,
BT
)
if
REVERSE
:
m_s
=
tl
.
where
(
o_i
[:,
None
]
<=
o_i
[
None
,
:],
1.
,
0.
)
else
:
m_s
=
tl
.
where
(
o_i
[:,
None
]
>=
o_i
[
None
,
:],
1.
,
0.
)
if
HEAD_FIRST
:
p_s
=
tl
.
make_block_ptr
(
s
+
(
bos
*
H
+
i_h
*
T
)
*
S
,
(
T
,
S
),
(
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
p_o
=
tl
.
make_block_ptr
(
o
+
(
bos
*
H
+
i_h
*
T
)
*
S
,
(
T
,
S
),
(
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
else
:
p_s
=
tl
.
make_block_ptr
(
s
+
(
bos
*
H
+
i_h
)
*
S
,
(
T
,
S
),
(
H
*
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
p_o
=
tl
.
make_block_ptr
(
o
+
(
bos
*
H
+
i_h
)
*
S
,
(
T
,
S
),
(
H
*
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
# [BT, BS]
b_s
=
tl
.
load
(
p_s
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_o
=
tl
.
dot
(
m_s
,
b_s
,
allow_tf32
=
False
)
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_local_cumsum_scalar
(
g
:
torch
.
Tensor
,
chunk_size
:
int
,
reverse
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
head_first
:
bool
=
False
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
float
)
->
torch
.
Tensor
:
if
head_first
:
B
,
H
,
T
=
g
.
shape
else
:
B
,
T
,
H
=
g
.
shape
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
g_org
,
g
=
g
,
torch
.
empty_like
(
g
,
dtype
=
output_dtype
or
g
.
dtype
)
grid
=
(
NT
,
B
*
H
)
chunk_local_cumsum_scalar_kernel
[
grid
](
g_org
,
g
,
cu_seqlens
,
chunk_indices
,
T
=
T
,
B
=
B
,
H
=
H
,
BT
=
BT
,
HEAD_FIRST
=
head_first
,
REVERSE
=
reverse
)
return
g
def
chunk_local_cumsum_vector
(
g
:
torch
.
Tensor
,
chunk_size
:
int
,
reverse
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
head_first
:
bool
=
False
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
float
)
->
torch
.
Tensor
:
if
head_first
:
B
,
H
,
T
,
S
=
g
.
shape
else
:
B
,
T
,
H
,
S
=
g
.
shape
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
chunk_size
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
g_org
,
g
=
g
,
torch
.
empty_like
(
g
,
dtype
=
output_dtype
or
g
.
dtype
)
def
grid
(
meta
):
return
(
triton
.
cdiv
(
meta
[
'S'
],
meta
[
'BS'
]),
NT
,
B
*
H
)
# keep cumulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_local_cumsum_vector_kernel
[
grid
](
g_org
,
g
,
cu_seqlens
,
chunk_indices
,
T
=
T
,
B
=
B
,
H
=
H
,
S
=
S
,
BT
=
BT
,
HEAD_FIRST
=
head_first
,
REVERSE
=
reverse
)
return
g
@
input_guard
def
chunk_local_cumsum
(
g
:
torch
.
Tensor
,
chunk_size
:
int
,
reverse
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
head_first
:
bool
=
False
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
float
,
**
kwargs
)
->
torch
.
Tensor
:
if
not
head_first
and
g
.
shape
[
1
]
<
g
.
shape
[
2
]:
warnings
.
warn
(
f
"Input tensor shape suggests potential format mismatch: seq_len (
{
g
.
shape
[
1
]
}
) < num_heads (
{
g
.
shape
[
2
]
}
). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
,
stacklevel
=
2
)
if
cu_seqlens
is
not
None
:
assert
g
.
shape
[
0
]
==
1
,
"Only batch size 1 is supported when cu_seqlens are provided"
if
len
(
g
.
shape
)
==
3
:
return
chunk_local_cumsum_scalar
(
g
,
chunk_size
,
reverse
,
cu_seqlens
,
head_first
,
output_dtype
)
elif
len
(
g
.
shape
)
==
4
:
return
chunk_local_cumsum_vector
(
g
,
chunk_size
,
reverse
,
cu_seqlens
,
head_first
,
output_dtype
)
else
:
raise
ValueError
(
f
"Unsupported input shape
{
g
.
shape
}
. "
f
"which should be (B, T, H, D) if `head_first=False` "
f
"or (B, H, T, D) otherwise"
)
vllm/model_executor/layers/fla/ops/fused_recurrent.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.op
import
exp
@
triton
.
heuristics
({
'USE_INITIAL_STATE'
:
lambda
args
:
args
[
'h0'
]
is
not
None
,
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
,
"IS_CONTINUOUS_BATCHING"
:
lambda
args
:
args
[
'ssm_state_indices'
]
is
not
None
,
"IS_SPEC_DECODING"
:
lambda
args
:
args
[
'num_accepted_tokens'
]
is
not
None
,
})
@
triton
.
jit
(
do_not_specialize
=
[
'N'
,
'T'
])
def
fused_recurrent_gated_delta_rule_fwd_kernel
(
q
,
k
,
v
,
g
,
beta
,
o
,
h0
,
ht
,
cu_seqlens
,
ssm_state_indices
,
num_accepted_tokens
,
scale
,
N
:
tl
.
constexpr
,
# num of sequences
T
:
tl
.
constexpr
,
# num of tokens
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HV
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
stride_init_state_token
:
tl
.
constexpr
,
stride_final_state_token
:
tl
.
constexpr
,
stride_indices_seq
:
tl
.
constexpr
,
stride_indices_tok
:
tl
.
constexpr
,
USE_INITIAL_STATE
:
tl
.
constexpr
,
# whether to use initial state
INPLACE_FINAL_STATE
:
tl
.
constexpr
,
# whether to store final state inplace
IS_BETA_HEADWISE
:
tl
.
constexpr
,
# whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
IS_CONTINUOUS_BATCHING
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
):
i_k
,
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_n
,
i_hv
=
i_nh
//
HV
,
i_nh
%
HV
i_h
=
i_hv
//
(
HV
//
H
)
if
IS_VARLEN
:
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int64
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int64
)
all
=
T
T
=
eos
-
bos
else
:
bos
,
eos
=
i_n
*
T
,
i_n
*
T
+
T
all
=
B
*
T
if
T
==
0
:
# no tokens to process for this sequence
return
o_k
=
i_k
*
BK
+
tl
.
arange
(
0
,
BK
)
o_v
=
i_v
*
BV
+
tl
.
arange
(
0
,
BV
)
p_q
=
q
+
(
bos
*
H
+
i_h
)
*
K
+
o_k
p_k
=
k
+
(
bos
*
H
+
i_h
)
*
K
+
o_k
p_v
=
v
+
(
bos
*
HV
+
i_hv
)
*
V
+
o_v
if
IS_BETA_HEADWISE
:
p_beta
=
beta
+
(
bos
*
HV
+
i_hv
)
*
V
+
o_v
else
:
p_beta
=
beta
+
bos
*
HV
+
i_hv
p_g
=
g
+
bos
*
HV
+
i_hv
p_o
=
o
+
((
i_k
*
all
+
bos
)
*
HV
+
i_hv
)
*
V
+
o_v
mask_k
=
o_k
<
K
mask_v
=
o_v
<
V
mask_h
=
mask_k
[:,
None
]
&
mask_v
[
None
,
:]
b_h
=
tl
.
zeros
([
BK
,
BV
],
dtype
=
tl
.
float32
)
if
USE_INITIAL_STATE
:
if
IS_CONTINUOUS_BATCHING
:
if
IS_SPEC_DECODING
:
i_t
=
tl
.
load
(
num_accepted_tokens
+
i_n
).
to
(
tl
.
int64
)
-
1
else
:
i_t
=
0
p_h0
=
h0
+
tl
.
load
(
ssm_state_indices
+
i_n
*
stride_indices_seq
+
i_t
).
to
(
tl
.
int64
)
*
stride_init_state_token
else
:
p_h0
=
h0
+
bos
*
HV
*
K
*
V
p_h0
=
p_h0
+
i_hv
*
K
*
V
+
o_k
[:,
None
]
*
V
+
o_v
[
None
,
:]
b_h
+=
tl
.
load
(
p_h0
,
mask
=
mask_h
,
other
=
0
).
to
(
tl
.
float32
)
for
i_t
in
range
(
0
,
T
):
b_q
=
tl
.
load
(
p_q
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_k
=
tl
.
load
(
p_k
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_v
=
tl
.
load
(
p_v
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
b_g
=
tl
.
load
(
p_g
).
to
(
tl
.
float32
)
if
USE_QK_L2NORM_IN_KERNEL
:
b_q
=
b_q
/
tl
.
sqrt
(
tl
.
sum
(
b_q
*
b_q
)
+
1e-6
)
b_k
=
b_k
/
tl
.
sqrt
(
tl
.
sum
(
b_k
*
b_k
)
+
1e-6
)
b_q
=
b_q
*
scale
# [BK, BV]
b_h
*=
exp
(
b_g
)
# [BV]
b_v
-=
tl
.
sum
(
b_h
*
b_k
[:,
None
],
0
)
if
IS_BETA_HEADWISE
:
b_beta
=
tl
.
load
(
p_beta
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
else
:
b_beta
=
tl
.
load
(
p_beta
).
to
(
tl
.
float32
)
b_v
*=
b_beta
# [BK, BV]
b_h
+=
b_k
[:,
None
]
*
b_v
[
None
,
:]
# [BV]
b_o
=
tl
.
sum
(
b_h
*
b_q
[:,
None
],
0
)
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
mask
=
mask_v
)
# keep the states for multi-query tokens
if
INPLACE_FINAL_STATE
:
p_ht
=
ht
+
tl
.
load
(
ssm_state_indices
+
i_n
*
stride_indices_seq
+
i_t
).
to
(
tl
.
int64
)
*
stride_final_state_token
else
:
p_ht
=
ht
+
(
bos
+
i_t
)
*
stride_final_state_token
p_ht
=
p_ht
+
i_hv
*
K
*
V
+
o_k
[:,
None
]
*
V
+
o_v
[
None
,
:]
tl
.
store
(
p_ht
,
b_h
.
to
(
p_ht
.
dtype
.
element_ty
),
mask
=
mask_h
)
p_q
+=
H
*
K
p_k
+=
H
*
K
p_o
+=
HV
*
V
p_v
+=
HV
*
V
p_g
+=
HV
p_beta
+=
HV
*
(
V
if
IS_BETA_HEADWISE
else
1
)
def
fused_recurrent_gated_delta_rule_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
ssm_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
HV
=
v
.
shape
[
2
]
N
=
B
if
cu_seqlens
is
None
else
len
(
cu_seqlens
)
-
1
BK
,
BV
=
triton
.
next_power_of_2
(
K
),
min
(
triton
.
next_power_of_2
(
V
),
8
)
NK
,
NV
=
triton
.
cdiv
(
K
,
BK
),
triton
.
cdiv
(
V
,
BV
)
assert
NK
==
1
,
"NK > 1 is not supported yet"
num_stages
=
3
num_warps
=
1
o
=
q
.
new_empty
(
NK
,
*
v
.
shape
)
if
inplace_final_state
:
final_state
=
initial_state
else
:
final_state
=
q
.
new_empty
(
T
,
HV
,
K
,
V
,
dtype
=
initial_state
.
dtype
)
stride_init_state_token
=
initial_state
.
stride
(
0
)
stride_final_state_token
=
final_state
.
stride
(
0
)
if
ssm_state_indices
is
None
:
stride_indices_seq
,
stride_indices_tok
=
1
,
1
elif
ssm_state_indices
.
ndim
==
1
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
(
0
),
1
else
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
()
grid
=
(
NK
,
NV
,
N
*
HV
)
fused_recurrent_gated_delta_rule_fwd_kernel
[
grid
](
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
beta
=
beta
,
o
=
o
,
h0
=
initial_state
,
ht
=
final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
scale
=
scale
,
N
=
N
,
T
=
T
,
B
=
B
,
H
=
H
,
HV
=
HV
,
K
=
K
,
V
=
V
,
BK
=
BK
,
BV
=
BV
,
stride_init_state_token
=
stride_init_state_token
,
stride_final_state_token
=
stride_final_state_token
,
stride_indices_seq
=
stride_indices_seq
,
stride_indices_tok
=
stride_indices_tok
,
IS_BETA_HEADWISE
=
beta
.
ndim
==
v
.
ndim
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
INPLACE_FINAL_STATE
=
inplace_final_state
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
o
=
o
.
squeeze
(
0
)
return
o
,
final_state
class
FusedRecurrentFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
ssm_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
):
o
,
final_state
=
fused_recurrent_gated_delta_rule_fwd
(
q
=
q
.
contiguous
(),
k
=
k
.
contiguous
(),
v
=
v
.
contiguous
(),
g
=
g
.
contiguous
(),
beta
=
beta
.
contiguous
(),
scale
=
scale
,
initial_state
=
initial_state
,
inplace_final_state
=
inplace_final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
,
)
return
o
,
final_state
def
fused_recurrent_gated_delta_rule
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
=
None
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
ssm_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, HV, V]`.
GVA is applied if `HV > H`.
g (torch.Tensor):
g (decays) of shape `[B, T, HV]`.
beta (torch.Tensor):
betas of shape `[B, T, HV]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
Indices to map the input sequences to the initial/final states.
num_accepted_tokens (Optional[torch.Tensor]):
Number of accepted tokens for each sequence during decoding.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HV, V]`.
final_state (torch.Tensor):
Final state of shape `[N, HV, K, V]`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
>>> o, ht = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
cu_seqlens=cu_seqlens
)
"""
if
cu_seqlens
is
not
None
and
q
.
shape
[
0
]
!=
1
:
raise
ValueError
(
f
"The batch size is expected to be 1 rather than
{
q
.
shape
[
0
]
}
when using `cu_seqlens`."
f
"Please flatten variable-length inputs before processing."
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
else
:
assert
scale
>
0
,
"scale must be positive"
if
beta
is
None
:
beta
=
torch
.
ones_like
(
q
[...,
0
])
o
,
final_state
=
FusedRecurrentFunction
.
apply
(
q
,
k
,
v
,
g
,
beta
,
scale
,
initial_state
,
inplace_final_state
,
cu_seqlens
,
ssm_state_indices
,
num_accepted_tokens
,
use_qk_l2norm_in_kernel
,
)
return
o
,
final_state
vllm/model_executor/layers/fla/ops/index.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
torch
from
vllm.triton_utils
import
triton
from
.utils
import
tensor_cache
@
tensor_cache
def
prepare_lens
(
cu_seqlens
:
torch
.
LongTensor
)
->
torch
.
LongTensor
:
return
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
@
tensor_cache
def
prepare_chunk_indices
(
cu_seqlens
:
torch
.
LongTensor
,
chunk_size
:
int
)
->
torch
.
LongTensor
:
indices
=
torch
.
cat
([
torch
.
arange
(
n
)
for
n
in
triton
.
cdiv
(
prepare_lens
(
cu_seqlens
),
chunk_size
).
tolist
()
])
return
torch
.
stack
([
indices
.
eq
(
0
).
cumsum
(
0
)
-
1
,
indices
],
1
).
to
(
cu_seqlens
)
@
tensor_cache
def
prepare_chunk_offsets
(
cu_seqlens
:
torch
.
LongTensor
,
chunk_size
:
int
)
->
torch
.
LongTensor
:
return
torch
.
cat
([
cu_seqlens
.
new_tensor
([
0
]),
triton
.
cdiv
(
prepare_lens
(
cu_seqlens
),
chunk_size
)
]).
cumsum
(
-
1
)
vllm/model_executor/layers/fla/ops/l2norm.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import
os
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
BT_LIST
=
[
8
,
16
,
32
,
64
,
128
]
USE_DEFAULT_FLA_NORM
=
int
(
os
.
getenv
(
"USE_DEFAULT_FLA_NORM"
,
"0"
))
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
,
16
,
32
]
],
key
=
[
'D'
])
@
triton
.
jit
def
l2norm_fwd_kernel1
(
x
,
y
,
D
,
BD
:
tl
.
constexpr
,
eps
,
):
i_t
=
tl
.
program_id
(
0
)
x
+=
i_t
*
D
y
+=
i_t
*
D
# Compute mean and variance
cols
=
tl
.
arange
(
0
,
BD
)
mask
=
cols
<
D
b_x
=
tl
.
load
(
x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
b_var
=
tl
.
sum
(
b_x
*
b_x
,
axis
=
0
)
b_rstd
=
1
/
tl
.
sqrt
(
b_var
+
eps
)
# tl.store(Rstd + i_t, rstd)
# Normalize and apply linear transformation
b_y
=
b_x
*
b_rstd
tl
.
store
(
y
+
cols
,
b_y
,
mask
=
mask
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BT'
:
BT
},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
,
16
]
for
BT
in
BT_LIST
],
key
=
[
'D'
])
@
triton
.
jit
(
do_not_specialize
=
[
"NB"
])
def
l2norm_fwd_kernel
(
x
,
y
,
eps
,
NB
,
T
,
D
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BD
:
tl
.
constexpr
,
):
i_t
=
tl
.
program_id
(
0
)
p_x
=
tl
.
make_block_ptr
(
x
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
b_x
=
tl
.
load
(
p_x
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_var
=
tl
.
sum
(
b_x
*
b_x
,
axis
=
1
)
b_y
=
b_x
/
tl
.
sqrt
(
b_var
+
eps
)[:,
None
]
p_y
=
tl
.
make_block_ptr
(
y
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
tl
.
store
(
p_y
,
b_y
.
to
(
p_y
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
@
triton
.
jit
def
l2norm_fwd_kernel2
(
X
,
Y
,
eps
,
M
,
N
:
tl
.
constexpr
,
MBLOCK
:
tl
.
constexpr
):
xoffset
=
tl
.
program_id
(
0
)
*
MBLOCK
row_idx
=
xoffset
+
tl
.
arange
(
0
,
MBLOCK
)[:,
None
]
xmask
=
row_idx
<
M
rindex
=
tl
.
arange
(
0
,
N
)[
None
,
:]
xs
=
tl
.
load
(
X
+
(
rindex
+
N
*
row_idx
),
xmask
).
to
(
tl
.
float32
)
square
=
tl
.
broadcast_to
(
xs
*
xs
,
[
MBLOCK
,
N
])
square_sum
=
tl
.
sum
(
tl
.
where
(
xmask
,
square
,
0
),
1
)[:,
None
]
rsqrt
=
tl
.
rsqrt
(
square_sum
+
eps
)
tl
.
store
(
Y
+
(
rindex
+
N
*
row_idx
),
xs
*
rsqrt
,
xmask
)
def
l2norm_fwd
(
x
:
torch
.
Tensor
,
eps
:
float
=
1e-6
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
x_shape_og
=
x
.
shape
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
# allocate output
if
output_dtype
is
None
:
y
=
torch
.
empty_like
(
x
)
else
:
y
=
torch
.
empty_like
(
x
,
dtype
=
output_dtype
)
assert
y
.
stride
(
-
1
)
==
1
T
,
D
=
x
.
shape
[
0
],
x
.
shape
[
-
1
]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE
=
65536
//
x
.
element_size
()
BD
=
min
(
MAX_FUSED_SIZE
,
triton
.
next_power_of_2
(
D
))
if
D
>
BD
:
raise
RuntimeError
(
"This layer doesn't support feature dim >= 64KB."
)
if
not
USE_DEFAULT_FLA_NORM
:
MBLOCK
=
32
# M, N = x.shape
l2norm_fwd_kernel2
[(
triton
.
cdiv
(
T
,
MBLOCK
),
)](
x
,
y
,
eps
,
T
,
D
,
MBLOCK
,
)
else
:
if
D
<=
512
:
NB
=
triton
.
cdiv
(
T
,
2048
)
def
grid
(
meta
):
return
(
triton
.
cdiv
(
T
,
meta
[
'BT'
]),
)
l2norm_fwd_kernel
[
grid
](
x
,
y
,
eps
,
NB
=
NB
,
T
=
T
,
D
=
D
,
BD
=
BD
,
)
else
:
l2norm_fwd_kernel1
[(
T
,
)](
x
,
y
,
eps
=
eps
,
D
=
D
,
BD
=
BD
,
)
return
y
.
view
(
x_shape_og
)
vllm/model_executor/layers/fla/ops/layernorm_guard.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Tri Dao
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2024, Tri Dao.
# ruff: noqa: E501
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
vllm.triton_utils
import
tl
,
triton
from
.utils
import
input_guard
def
rms_norm_ref
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
upcast
=
True
):
dtype
=
x
.
dtype
weight
=
weight
.
float
()
bias
=
bias
.
float
()
if
bias
is
not
None
else
None
if
upcast
:
x
=
x
.
float
()
z
=
z
.
float
()
if
z
is
not
None
else
z
if
z
is
not
None
and
not
norm_before_gate
:
x
=
x
*
F
.
silu
(
z
)
if
group_size
is
None
:
rstd
=
1
/
torch
.
sqrt
((
x
.
square
()).
mean
(
dim
=-
1
,
keepdim
=
True
)
+
eps
)
out
=
(
x
*
rstd
*
weight
)
+
bias
if
bias
is
not
None
else
(
x
*
rstd
*
weight
)
else
:
x_group
=
rearrange
(
x
,
"... (g d) -> ... g d"
,
d
=
group_size
)
rstd
=
1
/
torch
.
sqrt
((
x_group
.
square
()).
mean
(
dim
=-
1
,
keepdim
=
True
)
+
eps
)
out
=
rearrange
(
x_group
*
rstd
,
"... g d -> ... (g d)"
)
*
weight
if
bias
is
not
None
:
out
=
out
+
bias
if
z
is
not
None
and
norm_before_gate
:
out
*=
F
.
silu
(
z
)
return
out
.
to
(
dtype
)
@
triton
.
heuristics
({
"HAS_BIAS"
:
lambda
args
:
args
[
"B"
]
is
not
None
,
"HAS_Z"
:
lambda
args
:
args
[
"Z"
]
is
not
None
,
})
@
triton
.
jit
def
layer_norm_fwd_kernel
(
X
,
# pointer to the input
Y
,
# pointer to the output
W
,
# pointer to the weights
B
,
# pointer to the biases
Z
,
# pointer to the other branch
Mean
,
# pointer to the mean
Rstd
,
# pointer to the 1/std
stride_x_row
,
# how much to increase the pointer when moving by 1 row
stride_y_row
,
stride_z_row
,
M
,
# number of rows in X
N
,
# number of columns in X
eps
,
# epsilon to avoid division by zero
BLOCK_N
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
NORM_BEFORE_GATE
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
):
# Map the program id to the row of X and Y it should compute.
row
=
tl
.
program_id
(
0
)
group
=
tl
.
program_id
(
1
)
X
+=
row
*
stride_x_row
+
group
*
N
Y
+=
row
*
stride_y_row
+
group
*
N
if
HAS_Z
:
Z
+=
row
*
stride_z_row
+
group
*
N
if
not
IS_RMS_NORM
:
Mean
+=
group
*
M
Rstd
+=
group
*
M
W
+=
group
*
N
if
HAS_BIAS
:
B
+=
group
*
N
# Compute mean and variance
cols
=
tl
.
arange
(
0
,
BLOCK_N
)
x
=
tl
.
load
(
X
+
cols
,
mask
=
cols
<
N
,
other
=
0.
).
to
(
tl
.
float32
)
if
HAS_Z
and
not
NORM_BEFORE_GATE
:
z
=
tl
.
load
(
Z
+
cols
,
mask
=
cols
<
N
).
to
(
tl
.
float32
)
x
*=
z
*
tl
.
sigmoid
(
z
)
if
not
IS_RMS_NORM
:
mean
=
tl
.
sum
(
x
,
axis
=
0
)
/
N
tl
.
store
(
Mean
+
row
,
mean
)
xbar
=
tl
.
where
(
cols
<
N
,
x
-
mean
,
0.
)
var
=
tl
.
sum
(
xbar
*
xbar
,
axis
=
0
)
/
N
else
:
xbar
=
tl
.
where
(
cols
<
N
,
x
,
0.
)
var
=
tl
.
sum
(
xbar
*
xbar
,
axis
=
0
)
/
N
rstd
=
1
/
tl
.
sqrt
(
var
+
eps
)
tl
.
store
(
Rstd
+
row
,
rstd
)
# Normalize and apply linear transformation
mask
=
cols
<
N
w
=
tl
.
load
(
W
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
if
HAS_BIAS
:
b
=
tl
.
load
(
B
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
x_hat
=
(
x
-
mean
)
*
rstd
if
not
IS_RMS_NORM
else
x
*
rstd
y
=
x_hat
*
w
+
b
if
HAS_BIAS
else
x_hat
*
w
if
HAS_Z
and
NORM_BEFORE_GATE
:
z
=
tl
.
load
(
Z
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
y
*=
z
*
tl
.
sigmoid
(
z
)
# Write output
tl
.
store
(
Y
+
cols
,
y
,
mask
=
mask
)
def
layer_norm_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
eps
:
float
,
z
:
torch
.
Tensor
=
None
,
out
:
torch
.
Tensor
=
None
,
group_size
:
int
=
None
,
norm_before_gate
:
bool
=
True
,
is_rms_norm
:
bool
=
False
,
):
M
,
N
=
x
.
shape
if
group_size
is
None
:
group_size
=
N
assert
N
%
group_size
==
0
ngroups
=
N
//
group_size
assert
x
.
stride
(
-
1
)
==
1
if
z
is
not
None
:
assert
z
.
stride
(
-
1
)
==
1
assert
z
.
shape
==
(
M
,
N
)
assert
weight
.
shape
==
(
N
,
)
assert
weight
.
stride
(
-
1
)
==
1
if
bias
is
not
None
:
assert
bias
.
stride
(
-
1
)
==
1
assert
bias
.
shape
==
(
N
,
)
# allocate output
if
out
is
not
None
:
assert
out
.
shape
==
x
.
shape
else
:
out
=
torch
.
empty_like
(
x
)
assert
out
.
stride
(
-
1
)
==
1
mean
=
torch
.
empty
((
ngroups
*
M
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
if
not
is_rms_norm
else
None
rstd
=
torch
.
empty
((
ngroups
*
M
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE
=
65536
//
x
.
element_size
()
BLOCK_N
=
min
(
MAX_FUSED_SIZE
,
triton
.
next_power_of_2
(
group_size
))
if
group_size
>
BLOCK_N
:
raise
RuntimeError
(
"This layer norm doesn't support feature dim >= 64KB."
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK_N
//
256
,
1
),
8
)
grid
=
(
M
,
ngroups
)
layer_norm_fwd_kernel
[
grid
](
x
,
out
,
weight
,
bias
,
z
,
mean
,
rstd
,
x
.
stride
(
0
),
out
.
stride
(
0
),
z
.
stride
(
0
)
if
z
is
not
None
else
0
,
M
,
group_size
,
eps
,
BLOCK_N
=
BLOCK_N
,
NORM_BEFORE_GATE
=
norm_before_gate
,
IS_RMS_NORM
=
is_rms_norm
,
num_warps
=
num_warps
)
return
out
,
mean
,
rstd
class
LayerNormFn
(
torch
.
autograd
.
Function
):
@
input_guard
@
staticmethod
def
forward
(
ctx
,
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
z
is
not
None
:
assert
z
.
shape
==
x_shape_og
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
if
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
y
,
mean
,
rstd
=
layer_norm_fwd
(
x
,
weight
,
bias
,
eps
,
z
=
z
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
eps
=
eps
ctx
.
group_size
=
group_size
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
is_rms_norm
=
is_rms_norm
return
y
.
reshape
(
x_shape_og
)
def
layernorm_fn
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
):
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
)
def
rmsnorm_fn
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
):
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
)
class
LayerNormGated
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
:
float
=
1e-5
,
group_size
:
Optional
[
int
]
=
None
,
norm_before_gate
:
bool
=
True
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
group_size
=
group_size
self
.
norm_before_gate
=
norm_before_gate
self
.
reset_parameters
()
def
reset_parameters
(
self
):
torch
.
nn
.
init
.
ones_
(
self
.
weight
)
torch
.
nn
.
init
.
zeros_
(
self
.
bias
)
def
forward
(
self
,
x
,
z
=
None
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return
layernorm_fn
(
x
,
self
.
weight
,
self
.
bias
,
z
=
z
,
group_size
=
self
.
group_size
,
eps
=
self
.
eps
,
norm_before_gate
=
self
.
norm_before_gate
)
class
RMSNormGated
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
:
float
=
1e-5
,
group_size
:
Optional
[
int
]
=
None
,
norm_before_gate
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
register_parameter
(
"bias"
,
None
)
self
.
group_size
=
group_size
self
.
norm_before_gate
=
norm_before_gate
self
.
reset_parameters
()
def
reset_parameters
(
self
):
torch
.
nn
.
init
.
ones_
(
self
.
weight
)
def
forward
(
self
,
x
,
z
=
None
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return
rmsnorm_fn
(
x
,
self
.
weight
,
self
.
bias
,
z
=
z
,
eps
=
self
.
eps
,
group_size
=
self
.
group_size
,
norm_before_gate
=
self
.
norm_before_gate
)
vllm/model_executor/layers/fla/ops/op.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import
os
from
vllm.triton_utils
import
tl
,
tldevice
,
triton
if
os
.
environ
.
get
(
'FLA_USE_FAST_OPS'
,
'0'
)
==
'1'
:
div
=
tldevice
.
fast_dividef
exp
=
tldevice
.
fast_expf
log
=
tldevice
.
fast_logf
log2
=
tldevice
.
fast_log2f
else
:
@
triton
.
jit
def
div_normal
(
x
,
y
):
return
x
/
y
div
=
div_normal
exp
=
tl
.
exp
log
=
tl
.
log
log2
=
tl
.
log2
if
not
hasattr
(
tl
,
'gather'
):
@
triton
.
jit
def
gather
(
src
,
index
,
axis
,
_builder
=
None
):
# This is a fallback implementation when tl.gather is not supported
# In order to pass triton compiler, there is no actual gather operation
return
src
else
:
gather
=
tl
.
gather
vllm/model_executor/layers/fla/ops/solve_tril.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.utils
import
input_guard
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
,
5
]
],
key
=
[
'BT'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
solve_tril_16x16_kernel
(
A
,
Ad
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
=
A
+
(
bos
*
H
+
i_h
)
*
BT
Ad
=
Ad
+
(
bos
*
H
+
i_h
)
*
16
offset
=
(
i_t
*
16
)
%
BT
p_A
=
tl
.
make_block_ptr
(
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
16
,
offset
),
(
16
,
16
),
(
1
,
0
))
p_Ai
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
16
,
0
),
(
16
,
16
),
(
1
,
0
))
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A
=
-
tl
.
where
(
tl
.
arange
(
0
,
16
)[:,
None
]
>
tl
.
arange
(
0
,
16
)[
None
,
:],
b_A
,
0
)
o_i
=
tl
.
arange
(
0
,
16
)
for
i
in
range
(
1
,
min
(
16
,
T
-
i_t
*
16
)):
b_a
=
-
tl
.
load
(
A
+
(
i_t
*
16
+
i
)
*
H
*
BT
+
o_i
+
offset
)
b_a
=
b_a
+
tl
.
sum
(
b_a
[:,
None
]
*
b_A
,
0
)
mask
=
o_i
==
i
b_A
=
tl
.
where
(
mask
[:,
None
],
b_a
,
b_A
)
b_A
+=
o_i
[:,
None
]
==
o_i
[
None
,
:]
tl
.
store
(
p_Ai
,
b_A
.
to
(
p_Ai
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
,
5
]
],
key
=
[
'H'
,
'BT'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
merge_16x16_to_32x32_inverse_kernel
(
A
,
Ad
,
Ai
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
+=
(
bos
*
H
+
i_h
)
*
32
Ad
+=
(
bos
*
H
+
i_h
)
*
16
Ai
+=
(
bos
*
H
+
i_h
)
*
32
p_A_21
=
tl
.
make_block_ptr
(
A
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_11
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_22
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_11
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_22
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_11
=
tl
.
load
(
p_Ad_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_22
=
tl
.
load
(
p_Ad_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_21
=
-
tl
.
dot
(
tl
.
dot
(
Ai_22
,
A_21
,
input_precision
=
'ieee'
),
Ai_11
,
input_precision
=
'ieee'
)
tl
.
store
(
p_Ai_11
,
Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_22
,
Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_21
,
Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
,
5
]
],
key
=
[
'H'
,
'BT'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
merge_16x16_to_64x64_inverse_kernel
(
A
,
Ad
,
Ai
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
+=
(
bos
*
H
+
i_h
)
*
64
Ad
+=
(
bos
*
H
+
i_h
)
*
16
Ai
+=
(
bos
*
H
+
i_h
)
*
64
p_A_21
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_A_32
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
16
),
(
16
,
16
),
(
1
,
0
))
p_A_31
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_A_43
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
32
),
(
16
,
16
),
(
1
,
0
))
p_A_42
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
16
),
(
16
,
16
),
(
1
,
0
))
p_A_41
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_11
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_22
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_33
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_44
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
))
A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_32
=
tl
.
load
(
p_A_32
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_31
=
tl
.
load
(
p_A_31
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_43
=
tl
.
load
(
p_A_43
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_42
=
tl
.
load
(
p_A_42
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_41
=
tl
.
load
(
p_A_41
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_11
=
tl
.
load
(
p_Ad_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_22
=
tl
.
load
(
p_Ad_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_33
=
tl
.
load
(
p_Ad_33
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_44
=
tl
.
load
(
p_Ad_44
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_21
=
-
tl
.
dot
(
tl
.
dot
(
Ai_22
,
A_21
,
input_precision
=
'ieee'
),
Ai_11
,
input_precision
=
'ieee'
)
Ai_32
=
-
tl
.
dot
(
tl
.
dot
(
Ai_33
,
A_32
,
input_precision
=
'ieee'
),
Ai_22
,
input_precision
=
'ieee'
)
Ai_43
=
-
tl
.
dot
(
tl
.
dot
(
Ai_44
,
A_43
,
input_precision
=
'ieee'
),
Ai_33
,
input_precision
=
'ieee'
)
Ai_31
=
-
tl
.
dot
(
Ai_33
,
tl
.
dot
(
A_31
,
Ai_11
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_32
,
Ai_21
,
input_precision
=
'ieee'
),
input_precision
=
'ieee'
)
Ai_42
=
-
tl
.
dot
(
Ai_44
,
tl
.
dot
(
A_42
,
Ai_22
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_43
,
Ai_32
,
input_precision
=
'ieee'
),
input_precision
=
'ieee'
)
Ai_41
=
-
tl
.
dot
(
Ai_44
,
tl
.
dot
(
A_41
,
Ai_11
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_42
,
Ai_21
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_43
,
Ai_31
,
input_precision
=
'ieee'
),
input_precision
=
'ieee'
)
p_Ai_11
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_22
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_33
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
32
),
(
16
,
16
),
(
1
,
0
))
p_Ai_44
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
48
),
(
16
,
16
),
(
1
,
0
))
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_31
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_32
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_41
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_42
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_43
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
32
),
(
16
,
16
),
(
1
,
0
))
tl
.
store
(
p_Ai_11
,
Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_22
,
Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_33
,
Ai_33
.
to
(
p_Ai_33
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_44
,
Ai_44
.
to
(
p_Ai_44
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_21
,
Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_31
,
Ai_31
.
to
(
p_Ai_31
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_32
,
Ai_32
.
to
(
p_Ai_32
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_41
,
Ai_41
.
to
(
p_Ai_41
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_42
,
Ai_42
.
to
(
p_Ai_42
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_43
,
Ai_43
.
to
(
p_Ai_43
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
fill_zeros
=
tl
.
zeros
((
16
,
16
),
dtype
=
tl
.
float32
)
p_Ai_12
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_13
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
32
),
(
16
,
16
),
(
1
,
0
))
p_Ai_14
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
48
),
(
16
,
16
),
(
1
,
0
))
p_Ai_23
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
32
),
(
16
,
16
),
(
1
,
0
))
p_Ai_24
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
48
),
(
16
,
16
),
(
1
,
0
))
p_Ai_34
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
48
),
(
16
,
16
),
(
1
,
0
))
tl
.
store
(
p_Ai_12
,
fill_zeros
.
to
(
p_Ai_12
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_13
,
fill_zeros
.
to
(
p_Ai_13
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_14
,
fill_zeros
.
to
(
p_Ai_14
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_23
,
fill_zeros
.
to
(
p_Ai_23
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_24
,
fill_zeros
.
to
(
p_Ai_24
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_34
,
fill_zeros
.
to
(
p_Ai_34
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
@
input_guard
def
solve_tril
(
A
:
torch
.
Tensor
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
"""
Compute the inverse of the lower triangular matrix
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, K]
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`
Returns:
(I + A)^-1 with the same shape as A
"""
assert
A
.
shape
[
-
1
]
in
[
16
,
32
,
64
]
B
,
T
,
H
,
BT
=
A
.
shape
Ad
=
torch
.
empty
(
B
,
T
,
H
,
16
,
device
=
A
.
device
,
dtype
=
torch
.
float
if
BT
!=
16
else
output_dtype
)
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
16
)
if
cu_seqlens
is
not
None
else
None
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
16
)
solve_tril_16x16_kernel
[
NT
,
B
*
H
](
A
=
A
,
Ad
=
Ad
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
BT
=
BT
,
)
if
BT
==
16
:
return
Ad
Ai
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
A
.
device
,
dtype
=
output_dtype
)
merge_fn
=
merge_16x16_to_32x32_inverse_kernel
if
BT
==
32
else
merge_16x16_to_64x64_inverse_kernel
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
BT
)
merge_fn
[
NT
,
B
*
H
](
A
=
A
,
Ad
=
Ad
,
Ai
=
Ai
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
BT
=
BT
,
)
return
Ai
Prev
1
…
21
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment