Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fdd9daaf
Unverified
Commit
fdd9daaf
authored
Aug 29, 2024
by
Mor Zusman
Committed by
GitHub
Aug 28, 2024
Browse files
[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)
parent
8c56e57d
Changes
20
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2815 additions
and
31 deletions
+2815
-31
CMakeLists.txt
CMakeLists.txt
+2
-0
Dockerfile
Dockerfile
+0
-23
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+700
-0
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+144
-0
csrc/mamba/causal_conv1d/static_switch.h
csrc/mamba/causal_conv1d/static_switch.h
+28
-0
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+276
-0
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+593
-0
csrc/mamba/mamba_ssm/static_switch.h
csrc/mamba/mamba_ssm/static_switch.h
+28
-0
csrc/ops.h
csrc/ops.h
+22
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+25
-0
requirements-mamba.txt
requirements-mamba.txt
+0
-3
requirements-test.txt
requirements-test.txt
+1
-1
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+205
-0
tests/kernels/test_mamba_ssm.py
tests/kernels/test_mamba_ssm.py
+324
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+30
-0
vllm/model_executor/layers/mamba/__init__.py
vllm/model_executor/layers/mamba/__init__.py
+0
-0
vllm/model_executor/layers/mamba/ops/__init__.py
vllm/model_executor/layers/mamba/ops/__init__.py
+0
-0
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+86
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+346
-0
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+5
-4
No files found.
CMakeLists.txt
View file @
fdd9daaf
...
@@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_MakeAvailable
(
cutlass
)
FetchContent_MakeAvailable
(
cutlass
)
list
(
APPEND VLLM_EXT_SRC
list
(
APPEND VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
...
...
Dockerfile
View file @
fdd9daaf
...
@@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
...
@@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
python3
-m
pip
install
-r
requirements-cuda.txt
python3
-m
pip
install
-r
requirements-cuda.txt
COPY
requirements-mamba.txt requirements-mamba.txt
RUN
python3
-m
pip
install
packaging
RUN
python3
-m
pip
install
-r
requirements-mamba.txt
# cuda arch list used by torch
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# can be useful for both `dev` and `test`
...
@@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
...
@@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3
-m
pip
install
-r
requirements-dev.txt
python3
-m
pip
install
-r
requirements-dev.txt
#################### DEV IMAGE ####################
#################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM
dev as mamba-builder
# max jobs used for build
ARG
max_jobs=2
ENV
MAX_JOBS=${max_jobs}
WORKDIR
/usr/src/mamba
COPY
requirements-mamba.txt requirements-mamba.txt
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN
pip
--verbose
wheel
-r
requirements-mamba.txt
\
--no-build-isolation
--no-deps
--no-cache-dir
#################### MAMBA Build IMAGE ####################
#################### vLLM installation IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
# image with vLLM installed
FROM
nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
FROM
nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
...
@@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
...
@@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
python3
-m
pip
install
dist/
*
.whl
--verbose
python3
-m
pip
install
dist/
*
.whl
--verbose
RUN
--mount
=
type
=
bind
,from
=
mamba-builder,src
=
/usr/src/mamba,target
=
/usr/src/mamba
\
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
python3
-m
pip
install
/usr/src/mamba/
*
.whl
--no-cache-dir
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
.
/etc/environment
&&
\
.
/etc/environment
&&
\
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp
${
PYTHON_VERSION_STR
}
-cp
${
PYTHON_VERSION_STR
}
-linux_x86_64
.whl
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp
${
PYTHON_VERSION_STR
}
-cp
${
PYTHON_VERSION_STR
}
-linux_x86_64
.whl
...
...
csrc/mamba/causal_conv1d/causal_conv1d.cu
0 → 100644
View file @
fdd9daaf
This diff is collapsed.
Click to expand it.
csrc/mamba/causal_conv1d/causal_conv1d.h
0 → 100644
View file @
fdd9daaf
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
ConvParamsBase
{
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
width
;
bool
silu_activation
;
index_t
x_batch_stride
;
index_t
x_c_stride
;
index_t
x_l_stride
;
index_t
weight_c_stride
;
index_t
weight_width_stride
;
index_t
out_batch_stride
;
index_t
out_c_stride
;
index_t
out_l_stride
;
index_t
conv_state_batch_stride
;
index_t
conv_state_c_stride
;
index_t
conv_state_l_stride
;
// Common data pointers.
void
*
__restrict__
x_ptr
;
void
*
__restrict__
weight_ptr
;
void
*
__restrict__
bias_ptr
;
void
*
__restrict__
out_ptr
;
void
*
__restrict__
conv_state_ptr
;
void
*
__restrict__
seq_idx_ptr
;
// No __restrict__ since initial_states could be the same as final_states.
void
*
initial_states_ptr
;
index_t
initial_states_batch_stride
;
index_t
initial_states_l_stride
;
index_t
initial_states_c_stride
;
void
*
final_states_ptr
;
index_t
final_states_batch_stride
;
index_t
final_states_l_stride
;
index_t
final_states_c_stride
;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template
<
typename
T
>
__device__
inline
T
shuffle_xor
(
T
val
,
int
offset
)
{
return
__shfl_xor_sync
(
uint32_t
(
-
1
),
val
,
offset
);
}
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
std
::
max
(
ilist
);
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
std
::
min
(
a
,
b
);
}
#else
#include <hip/hip_bf16.h>
template
<
typename
T
>
__device__
inline
T
shuffle_xor
(
T
val
,
int
offset
)
{
return
__shfl_xor
(
val
,
offset
);
}
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
*
std
::
max_element
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
BYTES
>
struct
BytesToType
{};
template
<
>
struct
BytesToType
<
16
>
{
using
Type
=
uint4
;
static_assert
(
sizeof
(
Type
)
==
16
);
};
template
<
>
struct
BytesToType
<
8
>
{
using
Type
=
uint64_t
;
static_assert
(
sizeof
(
Type
)
==
8
);
};
template
<
>
struct
BytesToType
<
4
>
{
using
Type
=
uint32_t
;
static_assert
(
sizeof
(
Type
)
==
4
);
};
template
<
>
struct
BytesToType
<
2
>
{
using
Type
=
uint16_t
;
static_assert
(
sizeof
(
Type
)
==
2
);
};
template
<
>
struct
BytesToType
<
1
>
{
using
Type
=
uint8_t
;
static_assert
(
sizeof
(
Type
)
==
1
);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
SumOp
{
__device__
inline
T
operator
()(
T
const
&
x
,
T
const
&
y
)
{
return
x
+
y
;
}
};
template
<
int
THREADS
>
struct
Allreduce
{
static_assert
(
THREADS
==
32
||
THREADS
==
16
||
THREADS
==
8
||
THREADS
==
4
);
template
<
typename
T
,
typename
Operator
>
static
__device__
inline
T
run
(
T
x
,
Operator
&
op
)
{
constexpr
int
OFFSET
=
THREADS
/
2
;
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
OFFSET
));
return
Allreduce
<
OFFSET
>::
run
(
x
,
op
);
}
};
template
<
>
struct
Allreduce
<
2
>
{
template
<
typename
T
,
typename
Operator
>
static
__device__
inline
T
run
(
T
x
,
Operator
&
op
)
{
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
1
));
return
x
;
}
};
csrc/mamba/causal_conv1d/static_switch.h
0 → 100644
View file @
fdd9daaf
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
csrc/mamba/mamba_ssm/selective_scan.h
0 → 100644
View file @
fdd9daaf
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
SSMParamsBase
{
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
dstate
,
n_groups
,
n_chunks
;
int
dim_ngroups_ratio
;
bool
is_variable_B
;
bool
is_variable_C
;
bool
delta_softplus
;
index_t
A_d_stride
;
index_t
A_dstate_stride
;
index_t
B_batch_stride
;
index_t
B_d_stride
;
index_t
B_dstate_stride
;
index_t
B_group_stride
;
index_t
C_batch_stride
;
index_t
C_d_stride
;
index_t
C_dstate_stride
;
index_t
C_group_stride
;
index_t
u_batch_stride
;
index_t
u_d_stride
;
index_t
delta_batch_stride
;
index_t
delta_d_stride
;
index_t
z_batch_stride
;
index_t
z_d_stride
;
index_t
out_batch_stride
;
index_t
out_d_stride
;
index_t
out_z_batch_stride
;
index_t
out_z_d_stride
;
// Common data pointers.
void
*
__restrict__
A_ptr
;
void
*
__restrict__
B_ptr
;
void
*
__restrict__
C_ptr
;
void
*
__restrict__
D_ptr
;
void
*
__restrict__
u_ptr
;
void
*
__restrict__
delta_ptr
;
void
*
__restrict__
delta_bias_ptr
;
void
*
__restrict__
out_ptr
;
void
*
__restrict__
x_ptr
;
void
*
__restrict__
z_ptr
;
void
*
__restrict__
out_z_ptr
;
void
*
__restrict__
index_ptr
;
};
#ifndef USE_ROCM
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
std
::
max
(
ilist
);
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
std
::
min
(
a
,
b
);
}
#else
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
*
std
::
max_element
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
#endif
#define MAX_DSTATE 256
inline
__device__
float2
operator
+
(
const
float2
&
a
,
const
float2
&
b
){
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
};
}
inline
__device__
float3
operator
+
(
const
float3
&
a
,
const
float3
&
b
)
{
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
,
a
.
z
+
b
.
z
};
}
inline
__device__
float4
operator
+
(
const
float4
&
a
,
const
float4
&
b
){
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
,
a
.
z
+
b
.
z
,
a
.
w
+
b
.
w
};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
BYTES
>
struct
BytesToType
{};
template
<
>
struct
BytesToType
<
16
>
{
using
Type
=
uint4
;
static_assert
(
sizeof
(
Type
)
==
16
);
};
template
<
>
struct
BytesToType
<
8
>
{
using
Type
=
uint64_t
;
static_assert
(
sizeof
(
Type
)
==
8
);
};
template
<
>
struct
BytesToType
<
4
>
{
using
Type
=
uint32_t
;
static_assert
(
sizeof
(
Type
)
==
4
);
};
template
<
>
struct
BytesToType
<
2
>
{
using
Type
=
uint16_t
;
static_assert
(
sizeof
(
Type
)
==
2
);
};
template
<
>
struct
BytesToType
<
1
>
{
using
Type
=
uint8_t
;
static_assert
(
sizeof
(
Type
)
==
1
);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
scalar_t
,
int
N
>
struct
Converter
{
static
inline
__device__
void
to_float
(
const
scalar_t
(
&
src
)[
N
],
float
(
&
dst
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst
[
i
]
=
src
[
i
];
}
}
};
template
<
int
N
>
struct
Converter
<
at
::
Half
,
N
>
{
static
inline
__device__
void
to_float
(
const
at
::
Half
(
&
src
)[
N
],
float
(
&
dst
)[
N
])
{
static_assert
(
N
%
2
==
0
);
auto
&
src2
=
reinterpret_cast
<
const
half2
(
&
)[
N
/
2
]
>
(
src
);
auto
&
dst2
=
reinterpret_cast
<
float2
(
&
)[
N
/
2
]
>
(
dst
);
#pragma unroll
for
(
int
i
=
0
;
i
<
N
/
2
;
++
i
)
{
dst2
[
i
]
=
__half22float2
(
src2
[
i
]);
}
}
};
#if __CUDA_ARCH__ >= 800
template
<
int
N
>
struct
Converter
<
at
::
BFloat16
,
N
>
{
static
inline
__device__
void
to_float
(
const
at
::
BFloat16
(
&
src
)[
N
],
float
(
&
dst
)[
N
])
{
static_assert
(
N
%
2
==
0
);
auto
&
src2
=
reinterpret_cast
<
const
nv_bfloat162
(
&
)[
N
/
2
]
>
(
src
);
auto
&
dst2
=
reinterpret_cast
<
float2
(
&
)[
N
/
2
]
>
(
dst
);
#pragma unroll
for
(
int
i
=
0
;
i
<
N
/
2
;
++
i
)
{
dst2
[
i
]
=
__bfloat1622float2
(
src2
[
i
]);
}
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
scalar_t
>
struct
SSMScanOp
;
template
<
>
struct
SSMScanOp
<
float
>
{
__device__
__forceinline__
float2
operator
()(
const
float2
&
ab0
,
const
float2
&
ab1
)
const
{
return
make_float2
(
ab1
.
x
*
ab0
.
x
,
ab1
.
x
*
ab0
.
y
+
ab1
.
y
);
}
};
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template
<
typename
scalar_t
>
struct
SSMScanPrefixCallbackOp
{
using
scan_t
=
std
::
conditional_t
<
std
::
is_same_v
<
scalar_t
,
float
>
,
float2
,
float4
>
;
scan_t
running_prefix
;
// Constructor
__device__
SSMScanPrefixCallbackOp
(
scan_t
running_prefix_
)
:
running_prefix
(
running_prefix_
)
{}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__
scan_t
operator
()(
scan_t
block_aggregate
)
{
scan_t
old_prefix
=
running_prefix
;
running_prefix
=
SSMScanOp
<
scalar_t
>
()(
running_prefix
,
block_aggregate
);
return
old_prefix
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Ktraits
>
inline
__device__
void
load_input
(
typename
Ktraits
::
input_t
*
u
,
typename
Ktraits
::
input_t
(
&
u_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockLoadT
::
TempStorage
&
smem_load
,
int
seqlen
)
{
if
constexpr
(
Ktraits
::
kIsEvenLen
)
{
auto
&
smem_load_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadVecT
::
TempStorage
&>
(
smem_load
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
typename
Ktraits
::
BlockLoadVecT
(
smem_load_vec
).
Load
(
reinterpret_cast
<
vec_t
*>
(
u
),
reinterpret_cast
<
vec_t
(
&
)[
Ktraits
::
kNLoads
]
>
(
u_vals
)
#ifdef USE_ROCM
,
Ktraits
::
kNThreads
*
Ktraits
::
kNLoads
#endif
);
}
else
{
typename
Ktraits
::
BlockLoadT
(
smem_load
).
Load
(
u
,
u_vals
,
seqlen
,
0.
f
);
}
}
template
<
typename
Ktraits
>
inline
__device__
void
load_index
(
int
*
u
,
int
(
&
u_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockLoadIndexT
::
TempStorage
&
smem_load_index
,
int
seqlen
)
{
if
constexpr
(
Ktraits
::
kIsEvenLen
)
{
auto
&
smem_load_index_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadIndexVecT
::
TempStorage
&>
(
smem_load_index
);
Ktraits
::
BlockLoadIndexVecT
(
smem_load_index_vec
).
Load
(
reinterpret_cast
<
uint4
*>
(
u
),
reinterpret_cast
<
uint4
(
&
)[
Ktraits
::
kNLoadsIndex
]
>
(
u_vals
)
);
}
else
{
Ktraits
::
BlockLoadIndexT
(
smem_load_index
).
Load
(
u
,
u_vals
,
seqlen
,
0
);
}
}
template
<
typename
Ktraits
>
inline
__device__
void
load_weight
(
typename
Ktraits
::
input_t
*
Bvar
,
typename
Ktraits
::
weight_t
(
&
B_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockLoadWeightT
::
TempStorage
&
smem_load_weight
,
int
seqlen
)
{
constexpr
int
kNItems
=
Ktraits
::
kNItems
;
typename
Ktraits
::
input_t
B_vals_load
[
kNItems
];
if
constexpr
(
Ktraits
::
kIsEvenLen
)
{
auto
&
smem_load_weight_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadWeightVecT
::
TempStorage
&>
(
smem_load_weight
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
typename
Ktraits
::
BlockLoadWeightVecT
(
smem_load_weight_vec
).
Load
(
reinterpret_cast
<
vec_t
*>
(
Bvar
),
reinterpret_cast
<
vec_t
(
&
)[
Ktraits
::
kNLoads
]
>
(
B_vals_load
)
);
}
else
{
typename
Ktraits
::
BlockLoadWeightT
(
smem_load_weight
).
Load
(
Bvar
,
B_vals_load
,
seqlen
,
0.
f
);
}
// #pragma unroll
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
Converter
<
typename
Ktraits
::
input_t
,
kNItems
>::
to_float
(
B_vals_load
,
B_vals
);
}
template
<
typename
Ktraits
>
inline
__device__
void
store_output
(
typename
Ktraits
::
input_t
*
out
,
const
float
(
&
out_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockStoreT
::
TempStorage
&
smem_store
,
int
seqlen
)
{
typename
Ktraits
::
input_t
write_vals
[
Ktraits
::
kNItems
];
#pragma unroll
for
(
int
i
=
0
;
i
<
Ktraits
::
kNItems
;
++
i
)
{
write_vals
[
i
]
=
out_vals
[
i
];
}
if
constexpr
(
Ktraits
::
kIsEvenLen
)
{
auto
&
smem_store_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockStoreVecT
::
TempStorage
&>
(
smem_store
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
typename
Ktraits
::
BlockStoreVecT
(
smem_store_vec
).
Store
(
reinterpret_cast
<
vec_t
*>
(
out
),
reinterpret_cast
<
vec_t
(
&
)[
Ktraits
::
kNLoads
]
>
(
write_vals
)
);
}
else
{
typename
Ktraits
::
BlockStoreT
(
smem_store
).
Store
(
out
,
write_vals
,
seqlen
);
}
}
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
0 → 100644
View file @
fdd9daaf
This diff is collapsed.
Click to expand it.
csrc/mamba/mamba_ssm/static_switch.h
0 → 100644
View file @
fdd9daaf
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
csrc/ops.h
View file @
fdd9daaf
...
@@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
torch
::
Tensor
num_tokens_post_pad
);
std
::
vector
<
torch
::
Tensor
>
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
seq_idx_
,
const
c10
::
optional
<
at
::
Tensor
>&
initial_states_
,
const
c10
::
optional
<
at
::
Tensor
>&
final_states_out_
,
bool
silu_activation
);
#ifndef USE_ROCM
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
...
...
csrc/torch_bindings.cpp
View file @
fdd9daaf
...
@@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
ops
.
def
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
torch
::
kCUDA
,
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm_supports_fp8
);
&
cutlass_scaled_mm_supports_fp8
);
// Mamba selective scan kernel
ops
.
def
(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor? x) -> Tensor[]"
);
ops
.
impl
(
"selective_scan_fwd"
,
torch
::
kCUDA
,
&
selective_scan_fwd
);
ops
.
def
(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation) -> Tensor"
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
def
(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"bool silu_activation) -> Tensor"
);
ops
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
#endif
#endif
// Quantized GEMM for GPTQ.
// Quantized GEMM for GPTQ.
...
...
requirements-mamba.txt
deleted
100644 → 0
View file @
8c56e57d
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
requirements-test.txt
View file @
fdd9daaf
...
@@ -11,7 +11,7 @@ pytest-shard
...
@@ -11,7 +11,7 @@ pytest-shard
# testing utils
# testing utils
awscli
awscli
einops # required for MPT
and
qwen-vl
einops # required for MPT
,
qwen-vl
and Mamba
httpx
httpx
peft
peft
requests
requests
...
...
tests/kernels/test_causal_conv1d.py
0 → 100644
View file @
fdd9daaf
from
typing
import
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
def
causal_conv1d_ref
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_states
:
Optional
[
torch
.
Tensor
]
=
None
,
return_final_states
:
bool
=
False
,
final_states_out
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
dtype_in
=
x
.
dtype
x
=
x
.
to
(
weight
.
dtype
)
seqlen
=
x
.
shape
[
-
1
]
dim
,
width
=
weight
.
shape
if
initial_states
is
None
:
out
=
F
.
conv1d
(
x
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
width
-
1
,
groups
=
dim
)
else
:
x
=
torch
.
cat
([
initial_states
,
x
],
dim
=-
1
)
out
=
F
.
conv1d
(
x
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
0
,
groups
=
dim
)
out
=
out
[...,
:
seqlen
]
if
return_final_states
:
final_states
=
F
.
pad
(
x
,
(
width
-
1
-
x
.
shape
[
-
1
],
0
)).
to
(
dtype_in
)
# (batch, dim, width - 1)
if
final_states_out
is
not
None
:
final_states_out
.
copy_
(
final_states
)
else
:
final_states_out
=
final_states
out
=
(
out
if
activation
is
None
else
F
.
silu
(
out
)).
to
(
dtype
=
dtype_in
)
return
(
out
,
None
)
if
not
return_final_states
else
(
out
,
final_states_out
)
def
causal_conv1d_update_ref
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
out: (batch, dim)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
dtype_in
=
x
.
dtype
batch
,
dim
=
x
.
shape
width
=
weight
.
shape
[
1
]
assert
conv_state
.
shape
==
(
batch
,
dim
,
width
)
assert
weight
.
shape
==
(
dim
,
width
)
conv_state
.
copy_
(
torch
.
roll
(
conv_state
,
shifts
=-
1
,
dims
=-
1
))
# Update state (B D W)
conv_state
[:,
:,
-
1
]
=
x
out
=
torch
.
sum
(
conv_state
*
weight
,
dim
=-
1
)
# (B D)
if
bias
is
not
None
:
out
+=
bias
return
(
out
if
activation
is
None
else
F
.
silu
(
out
)).
to
(
dtype
=
dtype_in
)
@
pytest
.
mark
.
parametrize
(
"return_final_states"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_initial_states"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"channel_last"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
128
,
512
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
+
32
])
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
1
,
2
])
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
,
channel_last
,
has_initial_states
,
return_final_states
):
if
not
channel_last
and
(
has_initial_states
or
return_final_states
):
pytest
.
skip
(
"Only channel_last support initial_states or return_final_states"
)
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
torch
.
random
.
manual_seed
(
0
)
if
not
channel_last
:
x
=
torch
.
randn
(
batch
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
dtype
=
itype
)[:,
4096
:
4096
+
dim
,
:]
else
:
x
=
rearrange
(
torch
.
randn
(
batch
,
seqlen
,
4096
+
dim
+
64
,
device
=
device
,
dtype
=
itype
)[:,
:,
4096
:
4096
+
dim
],
"b s d -> b d s"
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
if
has_initial_states
:
initial_states
=
torch
.
randn
(
batch
,
width
-
1
,
dim
,
device
=
device
,
dtype
=
itype
).
transpose
(
1
,
2
)
else
:
initial_states
=
None
x_ref
=
x
.
detach
().
clone
()
weight_ref
=
weight
.
detach
().
clone
()
bias_ref
=
bias
.
detach
().
clone
()
if
bias
is
not
None
else
None
initial_states_ref
=
initial_states
.
detach
().
clone
(
)
if
initial_states
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
out
,
final_states
=
causal_conv1d_fn
(
x
,
weight
,
bias
,
initial_states
=
initial_states
,
return_final_states
=
return_final_states
,
activation
=
activation
)
out_ref
,
final_states_ref
=
causal_conv1d_ref
(
x_ref
,
weight_ref
,
bias_ref
,
initial_states
=
initial_states_ref
,
return_final_states
=
return_final_states
,
activation
=
activation
)
if
return_final_states
:
assert
final_states
is
not
None
and
final_states_ref
is
not
None
assert
torch
.
allclose
(
final_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
if
return_final_states
:
out
+=
F
.
sigmoid
(
final_states
).
sum
(
dim
=-
1
,
keepdim
=
True
)
out_ref
+=
F
.
sigmoid
(
final_states_ref
).
sum
(
dim
=-
1
,
keepdim
=
True
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
1
,
2
])
def
test_causal_conv1d_update
(
batch
,
dim
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
torch
.
random
.
manual_seed
(
0
)
batch
=
2
x
=
torch
.
randn
(
batch
,
dim
,
device
=
device
,
dtype
=
itype
)
conv_state
=
torch
.
randn
(
batch
,
dim
,
width
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
if
has_bias
:
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
else
:
bias
=
None
conv_state_ref
=
conv_state
.
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation
=
activation
)
out_ref
=
causal_conv1d_update_ref
(
x
,
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
tests/kernels/test_mamba_ssm.py
0 → 100644
View file @
fdd9daaf
import
pytest
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
def
selective_state_update_ref
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads
=
state
.
dim
()
>
3
if
state
.
dim
()
==
3
:
state
=
state
.
unsqueeze
(
1
)
if
x
.
dim
()
==
2
:
x
=
x
.
unsqueeze
(
1
)
if
dt
.
dim
()
==
2
:
dt
=
dt
.
unsqueeze
(
1
)
if
A
.
dim
()
==
2
:
A
=
A
.
unsqueeze
(
0
)
if
B
.
dim
()
==
2
:
B
=
B
.
unsqueeze
(
1
)
if
C
.
dim
()
==
2
:
C
=
C
.
unsqueeze
(
1
)
if
D
is
not
None
and
D
.
dim
()
==
1
:
D
=
D
.
unsqueeze
(
0
)
if
z
is
not
None
and
z
.
dim
()
==
2
:
z
=
z
.
unsqueeze
(
1
)
if
dt_bias
is
not
None
and
dt_bias
.
dim
()
==
1
:
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
batch
,
nheads
,
dim
,
dstate
=
state
.
shape
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
dt
.
shape
==
x
.
shape
assert
A
.
shape
==
(
nheads
,
dim
,
dstate
)
ngroups
=
B
.
shape
[
1
]
assert
nheads
%
ngroups
==
0
,
"nheads must be divisible by ngroups"
assert
B
.
shape
==
(
batch
,
ngroups
,
dstate
)
assert
C
.
shape
==
B
.
shape
if
D
is
not
None
:
assert
D
.
shape
==
(
nheads
,
dim
)
if
z
is
not
None
:
assert
z
.
shape
==
x
.
shape
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
dt
=
dt
+
dt_bias
dt
=
F
.
softplus
(
dt
)
if
dt_softplus
else
dt
dA
=
torch
.
exp
(
rearrange
(
dt
,
"b h d -> b h d 1"
)
*
A
)
# (batch, nheads, dim, dstate)
B
=
repeat
(
B
,
"b g n -> b (g h) n"
,
h
=
nheads
//
ngroups
)
# (batch, nheads, dstate)
C
=
repeat
(
C
,
"b g n -> b (g h) n"
,
h
=
nheads
//
ngroups
)
# (batch, nheads, dstate)
dB
=
rearrange
(
dt
,
"b h d -> b h d 1"
)
*
rearrange
(
B
,
"b h n -> b h 1 n"
)
# (batch, nheads, dim, dstate)
state
.
copy_
(
state
*
dA
+
dB
*
rearrange
(
x
,
"b h d -> b h d 1"
))
# (batch, dim, dstate
out
=
torch
.
einsum
(
"bhdn,bhn->bhd"
,
state
.
to
(
C
.
dtype
),
C
)
if
D
is
not
None
:
out
+=
(
x
*
D
).
to
(
out
.
dtype
)
out
=
(
out
if
z
is
None
else
out
*
F
.
silu
(
z
)).
to
(
x
.
dtype
)
if
not
has_heads
:
out
=
out
.
squeeze
(
1
)
return
out
def
selective_scan_ref
(
u
,
delta
,
A
,
B
,
C
,
D
=
None
,
z
=
None
,
delta_bias
=
None
,
delta_softplus
=
False
,
return_last_state
=
False
,
position_indices
=
None
,
prev_state
=
None
):
"""
u: r(B D L)
delta: r(B D L)
A: c(D N) or r(D N)
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
prev_state: r(B D N), fp32
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in
=
u
.
dtype
u
=
u
.
float
()
delta
=
delta
.
float
()
if
delta_bias
is
not
None
:
delta
=
delta
+
delta_bias
[...,
None
].
float
()
if
delta_softplus
:
delta
=
F
.
softplus
(
delta
)
batch
,
dim
,
dstate
=
u
.
shape
[
0
],
A
.
shape
[
0
],
A
.
shape
[
1
]
is_variable_B
=
B
.
dim
()
>=
3
is_variable_C
=
C
.
dim
()
>=
3
B
=
B
.
float
()
C
=
C
.
float
()
x
=
A
.
new_zeros
((
batch
,
dim
,
dstate
))
if
prev_state
is
None
else
prev_state
ys
=
[]
deltaA
=
torch
.
exp
(
torch
.
einsum
(
'bdl,dn->bdln'
,
delta
,
A
))
if
not
is_variable_B
:
deltaB_u
=
torch
.
einsum
(
'bdl,dn,bdl->bdln'
,
delta
,
B
,
u
)
else
:
if
B
.
dim
()
==
3
:
deltaB_u
=
torch
.
einsum
(
'bdl,bnl,bdl->bdln'
,
delta
,
B
,
u
)
else
:
B
=
repeat
(
B
,
"B G N L -> B (G H) N L"
,
H
=
dim
//
B
.
shape
[
1
])
deltaB_u
=
torch
.
einsum
(
'bdl,bdnl,bdl->bdln'
,
delta
,
B
,
u
)
if
is_variable_C
and
C
.
dim
()
==
4
:
C
=
repeat
(
C
,
"B G N L -> B (G H) N L"
,
H
=
dim
//
C
.
shape
[
1
])
last_state
=
None
for
i
in
range
(
u
.
shape
[
2
]):
if
position_indices
is
not
None
and
position_indices
[
0
,
i
]
==
0
:
x
=
deltaB_u
[:,
:,
i
]
else
:
x
=
deltaA
[:,
:,
i
]
*
x
+
deltaB_u
[:,
:,
i
]
if
not
is_variable_C
:
y
=
torch
.
einsum
(
'bdn,dn->bd'
,
x
,
C
)
else
:
if
C
.
dim
()
==
3
:
y
=
torch
.
einsum
(
'bdn,bn->bd'
,
x
,
C
[:,
:,
i
])
else
:
y
=
torch
.
einsum
(
'bdn,bdn->bd'
,
x
,
C
[:,
:,
:,
i
])
if
i
==
u
.
shape
[
2
]
-
1
:
last_state
=
x
ys
.
append
(
y
)
y
=
torch
.
stack
(
ys
,
dim
=
2
)
# (batch dim L)
out
=
y
if
D
is
None
else
y
+
u
*
rearrange
(
D
,
"d -> d 1"
)
if
z
is
not
None
:
out
=
out
*
F
.
silu
(
z
)
out
=
out
.
to
(
dtype
=
dtype_in
)
return
out
if
not
return_last_state
else
(
out
,
last_state
)
@
pytest
.
mark
.
parametrize
(
'wtype'
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
'itype'
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
128
,
256
,
512
,
1024
,
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
"return_last_state"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'has_delta_bias'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'delta_softplus'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'has_z'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'has_D'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"varBC_groups"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"is_variable_C"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"is_variable_B"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"scan_chunks"
,
[
1
,
2
,
3
])
def
test_selective_scan
(
is_variable_B
,
is_variable_C
,
varBC_groups
,
has_D
,
has_z
,
has_delta_bias
,
delta_softplus
,
return_last_state
,
seqlen
,
itype
,
wtype
,
scan_chunks
):
if
varBC_groups
>
1
and
(
not
is_variable_B
or
not
is_variable_C
):
pytest
.
skip
()
# This config is not applicable
device
=
'cuda'
rtol
,
atol
=
(
6e-4
,
2e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
3e-2
,
5e-2
rtolw
,
atolw
=
(
1e-3
,
1e-3
)
if
has_z
:
# If we have z, the errors on the weights seem higher
rtolw
=
max
(
rtolw
,
rtol
)
atolw
=
max
(
atolw
,
atol
)
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
2
dim
=
4
dstate
=
8
A
=
(
-
0.5
*
torch
.
rand
(
dim
,
dstate
,
device
=
device
,
dtype
=
wtype
))
if
not
is_variable_B
:
B_shape
=
[
dim
,
dstate
]
elif
varBC_groups
==
1
:
B_shape
=
[
batch_size
,
dstate
,
seqlen
]
else
:
B_shape
=
[
batch_size
,
varBC_groups
,
dstate
,
seqlen
]
B
=
torch
.
randn
(
B_shape
,
device
=
device
,
dtype
=
wtype
if
not
is_variable_B
else
itype
)
if
not
is_variable_C
:
C_shape
=
[
dim
,
dstate
]
elif
varBC_groups
==
1
:
C_shape
=
[
batch_size
,
dstate
,
seqlen
]
else
:
C_shape
=
[
batch_size
,
varBC_groups
,
dstate
,
seqlen
]
C
=
torch
.
randn
(
C_shape
,
device
=
device
,
dtype
=
wtype
if
not
is_variable_C
else
itype
)
D
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
)
if
has_D
else
None
z
=
torch
.
randn
(
batch_size
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
if
has_z
else
None
delta_bias
=
(
0.5
*
torch
.
rand
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
)
)
if
has_delta_bias
else
None
u
=
torch
.
randn
(
batch_size
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
delta
=
(
0.5
*
torch
.
rand
(
batch_size
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
))
state
=
None
state_ref
=
None
out
=
None
out_ref
=
None
outs
=
[]
for
c
in
range
(
scan_chunks
):
chunked_prompt_len
=
seqlen
//
scan_chunks
chunk_start
=
chunked_prompt_len
*
c
chunk_end
=
chunked_prompt_len
*
(
c
+
1
)
if
c
==
scan_chunks
-
1
:
chunk_end
=
seqlen
_B
=
B
if
is_variable_B
:
_B
=
B
[...,
chunk_start
:
chunk_end
]
_C
=
C
if
is_variable_B
:
_C
=
C
[...,
chunk_start
:
chunk_end
]
_z
=
z
if
has_z
:
assert
z
is
not
None
_z
=
z
[...,
chunk_start
:
chunk_end
]
out
,
*
rest
=
selective_scan_fn
(
u
[...,
chunk_start
:
chunk_end
],
delta
[...,
chunk_start
:
chunk_end
],
A
,
_B
,
_C
,
D
,
z
=
_z
,
delta_bias
=
delta_bias
,
delta_softplus
=
delta_softplus
,
return_last_state
=
return_last_state
,
prev_state
=
state
if
c
>
0
else
None
)
outs
.
append
(
out
)
if
return_last_state
:
state
=
rest
[
0
]
if
len
(
outs
)
>
1
:
out
=
torch
.
cat
(
outs
,
dim
=-
1
)
out_ref
,
*
rest
=
selective_scan_ref
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
=
z
,
delta_bias
=
delta_bias
,
delta_softplus
=
delta_softplus
,
return_last_state
=
return_last_state
)
if
return_last_state
:
state_ref
=
rest
[
0
]
assert
out
is
not
None
and
out_ref
is
not
None
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
if
return_last_state
:
assert
state
is
not
None
and
state_ref
is
not
None
assert
torch
.
allclose
(
state
,
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_selective_state_update
(
dim
,
dstate
,
has_z
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
if
torch
.
version
.
hip
:
atol
*=
2
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
1
state
=
torch
.
randn
(
batch_size
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref
=
state
.
detach
().
clone
()
out
=
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
assert
torch
.
allclose
(
state
,
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
vllm/_custom_ops.py
View file @
fdd9daaf
...
@@ -500,6 +500,36 @@ def ggml_mul_mat_a8(
...
@@ -500,6 +500,36 @@ def ggml_mul_mat_a8(
return
torch
.
ops
.
_C
.
ggml_mul_mat_a8
(
W
,
X
,
quant_type
,
row
)
return
torch
.
ops
.
_C
.
ggml_mul_mat_a8
(
W
,
X
,
quant_type
,
row
)
# mamba
def
causal_conv1d_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
seq_idx_
:
Optional
[
torch
.
Tensor
],
initial_states_
:
Optional
[
torch
.
Tensor
],
final_states_out_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
causal_conv1d_fwd
(
x
,
weight
,
bias_
,
seq_idx_
,
initial_states_
,
final_states_out_
,
silu_activation
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
)
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
D_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
delta_bias_
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
index_
:
Optional
[
torch
.
Tensor
],
x
:
Optional
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
return
torch
.
ops
.
_C
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D_
,
z_
,
delta_bias_
,
delta_softplus
,
index_
,
x
)
# moe
# moe
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/mamba/__init__.py
0 → 100644
View file @
fdd9daaf
vllm/model_executor/layers/mamba/ops/__init__.py
0 → 100644
View file @
fdd9daaf
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
0 → 100644
View file @
fdd9daaf
# Copyright (c) 2024, Tri Dao.
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
seq_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_states
:
Optional
[
torch
.
Tensor
]
=
None
,
return_final_states
:
bool
=
False
,
final_states_out
=
None
,
activation
:
str
=
"silu"
,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
if
x
.
stride
(
2
)
!=
1
and
x
.
stride
(
1
)
!=
1
:
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
if
seq_idx
is
not
None
:
assert
(
initial_states
is
None
),
"initial_states must be None if seq_idx is not None"
assert
(
not
return_final_states
),
"If seq_idx is not None, we don't return final_states_out"
seq_idx
=
seq_idx
.
contiguous
()
if
seq_idx
is
not
None
else
None
if
initial_states
is
not
None
and
(
initial_states
.
stride
(
2
)
!=
1
and
initial_states
.
stride
(
1
)
!=
1
):
initial_states
=
initial_states
.
contiguous
()
if
return_final_states
:
assert
(
x
.
stride
(
1
)
==
1
),
"Only channel-last layout support returning final_states_out"
if
final_states_out
is
not
None
:
assert
(
final_states_out
.
stride
(
2
)
==
1
or
final_states_out
.
stride
(
1
)
==
1
)
else
:
batch
,
dim
,
seqlen
=
x
.
shape
width
=
weight
.
shape
[
1
]
final_states_out
=
torch
.
empty
(
batch
,
width
-
1
,
dim
,
device
=
x
.
device
,
dtype
=
x
.
dtype
).
transpose
(
1
,
2
)
else
:
final_states_out
=
None
out
=
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
seq_idx
,
initial_states
,
final_states_out
,
activation
in
[
"silu"
,
"swish"
])
return
(
out
,
None
)
if
not
return_final_states
else
(
out
,
final_states_out
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
out: (batch, dim)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_bool
)
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
0 → 100644
View file @
fdd9daaf
# Copyright (c) 2024, Tri Dao, Albert Gu.
import
torch
import
triton
import
triton.language
as
tl
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
TRITON3
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
)
if
TRITON3
:
@
triton
.
jit
def
softplus
(
dt
):
dt
=
tl
.
where
(
dt
<=
20.0
,
tl
.
math
.
log
(
tl
.
math
.
exp
(
dt
)
+
1
),
dt
)
return
dt
else
:
@
triton
.
jit
def
softplus
(
dt
):
dt
=
tl
.
where
(
dt
<=
20.0
,
tl
.
math
.
log1p
(
tl
.
exp
(
dt
)),
dt
)
return
dt
@
triton
.
heuristics
(
{
"HAS_DT_BIAS"
:
lambda
args
:
args
[
"dt_bias_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_D"
:
lambda
args
:
args
[
"D_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_Z"
:
lambda
args
:
args
[
"z_ptr"
]
is
not
None
})
@
triton
.
heuristics
(
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])})
@
triton
.
jit
def
_selective_scan_update_kernel
(
# Pointers to matrices
state_ptr
,
x_ptr
,
dt_ptr
,
dt_bias_ptr
,
A_ptr
,
B_ptr
,
C_ptr
,
D_ptr
,
z_ptr
,
out_ptr
,
# Matrix dimensions
batch
,
nheads
,
dim
,
dstate
,
nheads_ngroups_ratio
,
# Strides
stride_state_batch
,
stride_state_head
,
stride_state_dim
,
stride_state_dstate
,
stride_x_batch
,
stride_x_head
,
stride_x_dim
,
stride_dt_batch
,
stride_dt_head
,
stride_dt_dim
,
stride_dt_bias_head
,
stride_dt_bias_dim
,
stride_A_head
,
stride_A_dim
,
stride_A_dstate
,
stride_B_batch
,
stride_B_group
,
stride_B_dstate
,
stride_C_batch
,
stride_C_group
,
stride_C_dstate
,
stride_D_head
,
stride_D_dim
,
stride_z_batch
,
stride_z_head
,
stride_z_dim
,
stride_out_batch
,
stride_out_head
,
stride_out_dim
,
# Meta-parameters
DT_SOFTPLUS
:
tl
.
constexpr
,
TIE_HDIM
:
tl
.
constexpr
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
HAS_DT_BIAS
:
tl
.
constexpr
,
HAS_D
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
x_ptr
+=
pid_b
*
stride_x_batch
+
pid_h
*
stride_x_head
dt_ptr
+=
pid_b
*
stride_dt_batch
+
pid_h
*
stride_dt_head
if
HAS_DT_BIAS
:
dt_bias_ptr
+=
pid_h
*
stride_dt_bias_head
A_ptr
+=
pid_h
*
stride_A_head
B_ptr
+=
pid_b
*
stride_B_batch
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_B_group
C_ptr
+=
pid_b
*
stride_C_batch
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_C_group
if
HAS_Z
:
z_ptr
+=
pid_b
*
stride_z_batch
+
pid_h
*
stride_z_head
out_ptr
+=
pid_b
*
stride_out_batch
+
pid_h
*
stride_out_head
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_SIZE_DSTATE
)
state_ptrs
=
state_ptr
+
(
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
)
x_ptrs
=
x_ptr
+
offs_m
*
stride_x_dim
dt_ptrs
=
dt_ptr
+
offs_m
*
stride_dt_dim
if
HAS_DT_BIAS
:
dt_bias_ptrs
=
dt_bias_ptr
+
offs_m
*
stride_dt_bias_dim
if
HAS_D
:
D_ptr
+=
pid_h
*
stride_D_head
A_ptrs
=
A_ptr
+
(
offs_m
[:,
None
]
*
stride_A_dim
+
offs_n
[
None
,
:]
*
stride_A_dstate
)
B_ptrs
=
B_ptr
+
offs_n
*
stride_B_dstate
C_ptrs
=
C_ptr
+
offs_n
*
stride_C_dstate
if
HAS_D
:
D_ptrs
=
D_ptr
+
offs_m
*
stride_D_dim
if
HAS_Z
:
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
state
=
tl
.
load
(
state_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
not
TIE_HDIM
:
dt
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
HAS_DT_BIAS
:
dt
+=
tl
.
load
(
dt_bias_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
DT_SOFTPLUS
:
dt
=
softplus
(
dt
)
A
=
tl
.
load
(
A_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
).
to
(
tl
.
float32
)
dA
=
tl
.
exp
(
A
*
dt
[:,
None
])
else
:
dt
=
tl
.
load
(
dt_ptr
).
to
(
tl
.
float32
)
if
HAS_DT_BIAS
:
dt
+=
tl
.
load
(
dt_bias_ptr
).
to
(
tl
.
float32
)
if
DT_SOFTPLUS
:
dt
=
softplus
(
dt
)
A
=
tl
.
load
(
A_ptr
).
to
(
tl
.
float32
)
dA
=
tl
.
exp
(
A
*
dt
)
# scalar, not a matrix
B
=
tl
.
load
(
B_ptrs
,
mask
=
offs_n
<
dstate
,
other
=
0.0
).
to
(
tl
.
float32
)
C
=
tl
.
load
(
C_ptrs
,
mask
=
offs_n
<
dstate
,
other
=
0.0
).
to
(
tl
.
float32
)
if
HAS_D
:
D
=
tl
.
load
(
D_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
HAS_Z
:
z
=
tl
.
load
(
z_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
state
=
state
*
dA
+
dB
*
x
[:,
None
]
tl
.
store
(
state_ptrs
,
state
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
))
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
if
HAS_D
:
out
+=
x
*
D
if
HAS_Z
:
out
*=
z
*
tl
.
sigmoid
(
z
)
tl
.
store
(
out_ptrs
,
out
,
mask
=
offs_m
<
dim
)
def
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads
=
state
.
dim
()
>
3
if
state
.
dim
()
==
3
:
state
=
state
.
unsqueeze
(
1
)
if
x
.
dim
()
==
2
:
x
=
x
.
unsqueeze
(
1
)
if
dt
.
dim
()
==
2
:
dt
=
dt
.
unsqueeze
(
1
)
if
A
.
dim
()
==
2
:
A
=
A
.
unsqueeze
(
0
)
if
B
.
dim
()
==
2
:
B
=
B
.
unsqueeze
(
1
)
if
C
.
dim
()
==
2
:
C
=
C
.
unsqueeze
(
1
)
if
D
is
not
None
and
D
.
dim
()
==
1
:
D
=
D
.
unsqueeze
(
0
)
if
z
is
not
None
and
z
.
dim
()
==
2
:
z
=
z
.
unsqueeze
(
1
)
if
dt_bias
is
not
None
and
dt_bias
.
dim
()
==
1
:
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
batch
,
nheads
,
dim
,
dstate
=
state
.
shape
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
dt
.
shape
==
x
.
shape
assert
A
.
shape
==
(
nheads
,
dim
,
dstate
)
ngroups
=
B
.
shape
[
1
]
assert
nheads
%
ngroups
==
0
,
"nheads must be divisible by ngroups"
assert
B
.
shape
==
(
batch
,
ngroups
,
dstate
)
assert
C
.
shape
==
B
.
shape
if
D
is
not
None
:
assert
D
.
shape
==
(
nheads
,
dim
)
if
z
is
not
None
:
assert
z
.
shape
==
x
.
shape
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
out
=
torch
.
empty_like
(
x
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
'BLOCK_SIZE_M'
]),
batch
,
nheads
)
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
(
0
,
0
,
0
))
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M
,
num_warps
=
((
32
,
4
)
if
dstate
<=
16
else
((
16
,
4
)
if
dstate
<=
32
else
((
8
,
4
)
if
dstate
<=
64
else
((
4
,
4
)
if
dstate
<=
128
else
((
4
,
8
))))))
tie_hdim
=
A
.
stride
(
-
1
)
==
0
and
A
.
stride
(
-
2
)
==
0
and
dt
.
stride
(
-
1
)
==
0
and
dt_bias
.
stride
(
-
1
)
==
0
with
torch
.
cuda
.
device
(
x
.
device
.
index
):
_selective_scan_update_kernel
[
grid
](
state
,
x
,
dt
,
dt_bias
,
A
,
B
,
C
,
D
,
z
,
out
,
batch
,
nheads
,
dim
,
dstate
,
nheads
//
ngroups
,
state
.
stride
(
0
),
state
.
stride
(
1
),
state
.
stride
(
2
),
state
.
stride
(
3
),
x
.
stride
(
0
),
x
.
stride
(
1
),
x
.
stride
(
2
),
dt
.
stride
(
0
),
dt
.
stride
(
1
),
dt
.
stride
(
2
),
*
(
dt_bias
.
stride
(
0
),
dt_bias
.
stride
(
1
))
if
dt_bias
is
not
None
else
0
,
A
.
stride
(
0
),
A
.
stride
(
1
),
A
.
stride
(
2
),
B
.
stride
(
0
),
B
.
stride
(
1
),
B
.
stride
(
2
),
C
.
stride
(
0
),
C
.
stride
(
1
),
C
.
stride
(
2
),
*
(
D
.
stride
(
0
),
D
.
stride
(
1
))
if
D
is
not
None
else
0
,
z_strides
[
0
],
z_strides
[
1
],
z_strides
[
2
],
out
.
stride
(
0
),
out
.
stride
(
1
),
out
.
stride
(
2
),
dt_softplus
,
tie_hdim
,
BLOCK_SIZE_M
,
num_warps
=
num_warps
,
)
if
not
has_heads
:
out
=
out
.
squeeze
(
1
)
return
out
def
selective_scan_fn
(
u
,
delta
,
A
,
B
,
C
,
D
=
None
,
z
=
None
,
delta_bias
=
None
,
delta_softplus
=
False
,
return_last_state
=
False
,
position_indices
=
None
,
prev_state
=
None
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
if
u
.
stride
(
-
1
)
!=
1
:
u
=
u
.
contiguous
()
if
delta
.
stride
(
-
1
)
!=
1
:
delta
=
delta
.
contiguous
()
if
D
is
not
None
:
D
=
D
.
contiguous
()
if
B
.
stride
(
-
1
)
!=
1
:
B
=
B
.
contiguous
()
if
C
.
stride
(
-
1
)
!=
1
:
C
=
C
.
contiguous
()
if
z
is
not
None
and
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
if
B
.
dim
()
==
3
:
B
=
B
.
unsqueeze
(
1
)
if
C
.
dim
()
==
3
:
C
=
C
.
unsqueeze
(
1
)
n_chunks
=
int
((
u
.
shape
[
-
1
]
+
2048
-
1
)
/
2048
)
x
=
torch
.
zeros
((
u
.
shape
[
0
],
u
.
shape
[
1
],
n_chunks
,
int
(
A
.
shape
[
1
]
*
2
),
),
device
=
u
.
device
,
dtype
=
torch
.
float32
,
requires_grad
=
False
)
x
[:,
:,
0
,
0
::
2
]
=
1
if
prev_state
is
not
None
:
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
out
,
x
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
position_indices
,
x
)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
if
z
is
None
:
return
out
if
not
return_last_state
else
(
out
,
last_state
)
else
:
out_z
=
rest
[
0
]
return
out_z
if
not
return_last_state
else
(
out_z
,
last_state
)
vllm/model_executor/models/jamba.py
View file @
fdd9daaf
...
@@ -4,9 +4,6 @@ from dataclasses import dataclass
...
@@ -4,9 +4,6 @@ from dataclasses import dataclass
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
from
mamba_ssm.ops.selective_scan_interface
import
selective_scan_fn
from
mamba_ssm.ops.triton.selective_state_update
import
selective_state_update
from
torch
import
nn
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
transformers
import
JambaConfig
from
transformers
import
JambaConfig
...
@@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
...
@@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module):
...
@@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module):
(
self
.
conv_kernel_size
-
hidden_states
.
shape
[
-
1
],
0
))
(
self
.
conv_kernel_size
-
hidden_states
.
shape
[
-
1
],
0
))
cache_params
.
conv_state
.
copy_
(
conv_states
)
cache_params
.
conv_state
.
copy_
(
conv_states
)
hidden_states
=
causal_conv1d_fn
(
hidden_states
,
_
=
causal_conv1d_fn
(
hidden_states
,
hidden_states
,
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment