Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
MambaVision_pytorch
Commits
2eefe3d6
Commit
2eefe3d6
authored
Sep 29, 2024
by
luopl
Browse files
add mamba
parent
b7535e7c
Pipeline
#1735
failed with stages
in 0 seconds
Changes
65
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2606 additions
and
0 deletions
+2606
-0
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu
+10
-0
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu
+10
-0
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+376
-0
mamba/csrc/selective_scan/static_switch.h
mamba/csrc/selective_scan/static_switch.h
+25
-0
mamba/csrc/selective_scan/uninitialized_copy.cuh
mamba/csrc/selective_scan/uninitialized_copy.cuh
+77
-0
mamba/evals/lm_harness_eval.py
mamba/evals/lm_harness_eval.py
+39
-0
mamba/mamba_ssm/__init__.py
mamba/mamba_ssm/__init__.py
+6
-0
mamba/mamba_ssm/distributed/__init__.py
mamba/mamba_ssm/distributed/__init__.py
+0
-0
mamba/mamba_ssm/distributed/distributed_utils.py
mamba/mamba_ssm/distributed/distributed_utils.py
+144
-0
mamba/mamba_ssm/distributed/tensor_parallel.py
mamba/mamba_ssm/distributed/tensor_parallel.py
+296
-0
mamba/mamba_ssm/models/__init__.py
mamba/mamba_ssm/models/__init__.py
+0
-0
mamba/mamba_ssm/models/config_mamba.py
mamba/mamba_ssm/models/config_mamba.py
+18
-0
mamba/mamba_ssm/models/mixer_seq_simple.py
mamba/mamba_ssm/models/mixer_seq_simple.py
+309
-0
mamba/mamba_ssm/modules/__init__.py
mamba/mamba_ssm/modules/__init__.py
+0
-0
mamba/mamba_ssm/modules/block.py
mamba/mamba_ssm/modules/block.py
+91
-0
mamba/mamba_ssm/modules/mamba2.py
mamba/mamba_ssm/modules/mamba2.py
+383
-0
mamba/mamba_ssm/modules/mamba2_simple.py
mamba/mamba_ssm/modules/mamba2_simple.py
+200
-0
mamba/mamba_ssm/modules/mamba_simple.py
mamba/mamba_ssm/modules/mamba_simple.py
+294
-0
mamba/mamba_ssm/modules/mha.py
mamba/mamba_ssm/modules/mha.py
+294
-0
mamba/mamba_ssm/modules/mlp.py
mamba/mamba_ssm/modules/mlp.py
+34
-0
No files found.
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu
0 → 100644
View file @
2eefe3d6
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel.
#include "selective_scan_fwd_kernel.cuh"

// Explicit instantiations of the forward selective-scan entry point for the
// fp16 (at::Half) input type, one for real-valued weights (float) and one for
// complex-valued weights (complex_t). The template definition lives in
// selective_scan_fwd_kernel.cuh; instantiating it here keeps each
// input-type/weight-type combination in its own translation unit.
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu
0 → 100644
View file @
2eefe3d6
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel.
#include "selective_scan_fwd_kernel.cuh"

// Explicit instantiations of the forward selective-scan entry point for the
// fp32 (float) input type, one for real-valued weights (float) and one for
// complex-valued weights (complex_t). The template definition lives in
// selective_scan_fwd_kernel.cuh; instantiating it here keeps each
// input-type/weight-type combination in its own translation unit.
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh
0 → 100644
View file @
2eefe3d6
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "selective_scan.h"
#include "selective_scan_common.h"
#include "static_switch.h"
// Compile-time configuration bundle for selective_scan_fwd_kernel.
// Every launch-time specialization (thread count, items per thread, rows per
// block, even/uneven sequence length, variable B/C, gating z, data types) is
// encoded in the template parameters so all derived constants and cub
// collective types below are computed at compile time.
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
         bool kIsVariableB_, bool kIsVariableC_,
         bool kHasZ_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
    // Vectorized loads below assume a multiple of 4 items per thread.
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;
    static constexpr int kNBytes = sizeof(input_t);
    // Only 2-byte (half) and 4-byte (float) element types are supported.
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Elements per vectorized memory transaction: 4 floats (16B) or up to 8 halves.
    static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    // Number of vectorized loads needed to cover kNItems per thread.
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
    static constexpr bool kIsEvenLen = kIsEvenLen_;
    static constexpr bool kIsVariableB = kIsVariableB_;
    static constexpr bool kIsVariableC = kIsVariableC_;
    static constexpr bool kHasZ = kHasZ_;
    // Direct (non-transposed) I/O is only valid when the sequence length is an
    // exact multiple of the chunk size and one vector load covers all items.
    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    // Scan element: (A, Bu) pair; complex weights need twice the storage.
    using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    // Complex weights interleave (real, imag), hence 2x the items per thread.
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    // Shared memory for I/O: the load/store temp storages alias each other (the
    // kernel reuses the same region with __syncthreads() between phases), so
    // only the largest footprint is reserved. The weight loaders may need two
    // instances (B and C), hence the (kIsVariableB + kIsVariableC) factor.
    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
                                                   sizeof(typename BlockLoadVecT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                   sizeof(typename BlockStoreT::TempStorage),
                                                   sizeof(typename BlockStoreVecT::TempStorage)});
    // Total static shared-memory footprint: I/O region + block-scan scratch.
    // (The kernel appends extra dynamic space for running prefixes after this.)
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};
// Forward selective-scan kernel. One block handles one batch element
// (blockIdx.x) and kNRows channels starting at blockIdx.y * kNRows. The
// sequence is processed in chunks of kNThreads * kNItems timesteps; an
// inclusive block-wide scan computes the recurrence per (row, state) pair and
// the running prefix is carried across chunks through shared memory.
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr bool kIsComplex = Ktraits::kIsComplex;
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory. The load/store/weight temp storages all alias the base of
    // smem_ (the phases are separated by __syncthreads()); the scan scratch
    // starts at kSmemIOSize and the running-prefix slots start at kSmemSize.
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    // Second weight-loader instance, used when both B and C vary over seqlen.
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
    // Per-(row,state) running prefix carried between chunks; lives past the
    // statically-sized region (the launcher adds the extra dynamic smem).
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    // Base pointers for this (batch, channel-group). B/C have two views: the
    // per-channel constant view (B, C) and the per-timestep variable view
    // (Bvar, Cvar); which one is read depends on kIsVariableB/kIsVariableC.
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * kNRows * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * kNRows * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    // Per-chunk final states are also written out to params.x_ptr (used by the
    // backward pass / state return — presumably; confirm against callers).
    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;

    // Optional skip-connection coefficient D and delta bias, zero when absent.
    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
        }
    }

    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
    // }

    constexpr int kChunkSize = kNThreads * kNItems;
    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
        __syncthreads();
        // Load this chunk's u and delta for each row; the loads share the same
        // smem region, so non-direct I/O needs a barrier between them.
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
        }
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    // softplus(x) = log1p(exp(x)); for x > 20 it is x to within
                    // float precision, so skip the transcendental calls.
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                // Start the output accumulator with the skip connection D * u.
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
        // One block-wide scan per state dimension.
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                if constexpr (!kIsComplex) {
                    A_val[r] *= kLog2e;
                } else {
                    A_val[r].real_ *= kLog2e;
                }
            }
            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
            // If both B and C vary, this is unused.
            weight_t BC_val[kNRows];
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (kIsVariableB) {
                // Complex weights are stored interleaved, hence the 2x count.
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals, smem_load_weight,
                    (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
                if constexpr (!kIsVariableC) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                    }
                }
            }
            if constexpr (kIsVariableC) {
                // Use the second weight-loader smem instance only when B also
                // varies (otherwise the first one is free).
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals, smem_load_weight_C,
                    (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
                    }
                }
            }
            if constexpr (!kIsVariableB && !kIsVariableC) {
                #pragma unroll
                for (int r = 0; r < kNRows; ++r) {
                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                }
            }

            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
                // Each scan element is (exp(delta*A), B*delta*u): composing
                // them under SSMScanOp realizes x_t = a_t * x_{t-1} + b_t.
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    if constexpr (!kIsComplex) {
                        thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                     !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                            // Pad positions past seqlen with the identity (1, 0).
                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                                thread_data[i] = make_float2(1.f, 0.f);
                            }
                        }
                    } else {
                        // Pytorch's implementation of complex exp (which calls thrust) is very slow
                        complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
                        weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
                        thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                                thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
                            }
                        }
                    }
                }
                // Initialize running total
                scan_t running_prefix;
                if constexpr (!kIsComplex) {
                    // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
                } else {
                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                }
                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    // NOTE(review): the write index omits the r * MAX_DSTATE
                    // offset that the read above uses — correct only for
                    // kNRows == 1 (which is what the launcher instantiates);
                    // confirm before enabling kNRows > 1.
                    smem_running_prefix[state_idx] = prefix_op.running_prefix;
                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const weight_t C_val = !kIsVariableC
                        ? BC_val[r]
                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
                    if constexpr (!kIsComplex) {
                        // .y of the scan result is the state x_t; y += C * x_t.
                        out_vals[r][i] += thread_data[i].y * C_val;
                    } else {
                        // For complex weights, 2 * Re(C * x_t) accumulates both
                        // conjugate halves.
                        out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
                    }
                }
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
        }

        if constexpr (kHasZ) {
            // Optional SiLU gating: out_z = out * z * sigmoid(z).
            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                input_t z_vals[kNItems];
                __syncthreads();
                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    float z_val = z_vals[i];
                    // z * sigmoid(z), i.e. SiLU.
                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
                }
                __syncthreads();
                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
            }
        }

        // Advance the per-timestep B/C views (complex is interleaved, so 2x).
        Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
        Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
    }
}
// Resolve the runtime flags (even length, variable B/C, gating z) into the 16
// compile-time kernel specializations via nested BOOL_SWITCH, size the dynamic
// shared memory, and launch the kernel on `stream`.
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
    // processing 1 row.
    constexpr int kNRows = 1;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
                    // Dynamic smem = the traits' I/O + scan region, plus the
                    // running-prefix slots the kernel indexes past kSmemSize.
                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                    // One block per (batch element, channel group of kNRows).
                    dim3 grid(params.batch, params.dim / kNRows);
                    // Had to change this substantially since potentially the hip
                    // interface for setting kernel launch attributes is slightly different from
                    // cuda's. In particular, it seems to expect a plain const void * pointer.
                    auto kernel = &selective_scan_fwd_kernel<Ktraits>;
                    // Above 48 KB of dynamic shared memory, CUDA requires an
                    // explicit opt-in via cudaFuncSetAttribute.
                    if (kSmemSize >= 48 * 1024) {
                        #ifndef USE_ROCM
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                        #else
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                        std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
                        #endif
                    }
                    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
            });
        });
    });
}
// Public entry point: pick a (threads, items-per-thread) launch configuration
// so a single chunk (kNThreads * kNItems timesteps) covers the sequence when
// possible: 32x4=128, 32x8=256, 32x16=512, 64x16=1024, else 128x16=2048 per
// chunk on CUDA. ROCm uses 64-thread configs to match its wavefront size —
// presumably; confirm against upstream tuning.
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
    #ifndef USE_ROCM
        if (params.seqlen <= 128) {
            selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 256) {
            selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
        }
    #else
        if (params.seqlen <= 256) {
            selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
        }
    #endif
}
mamba/csrc/selective_scan/static_switch.h
0 → 100644
View file @
2eefe3d6
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
// Expands to an immediately-invoked lambda that tests COND at runtime and, in
// each branch, declares CONST_NAME as a `constexpr bool` before invoking the
// caller-supplied lambda — so the body can use CONST_NAME as a template
// argument. The inner lambda's return value is forwarded.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                    \
    if (COND) {                            \
      constexpr bool CONST_NAME = true;    \
      return __VA_ARGS__();                \
    } else {                               \
      constexpr bool CONST_NAME = false;   \
      return __VA_ARGS__();                \
    }                                      \
  }()
mamba/csrc/selective_scan/uninitialized_copy.cuh
0 → 100644
View file @
2eefe3d6
/******************************************************************************
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#ifndef USE_ROCM
#include <cub/config.cuh>
#include <cuda/std/type_traits>
#else
#include <hipcub/hipcub.hpp>
// Map ::cuda::std to the standard std namespace
namespace
cuda
{
namespace
std
=
::
std
;
}
#endif
namespace
detail
{
#if defined(_NVHPC_CUDA)
// NVHPC build: always use placement new, regardless of triviality, to work
// around the compiler defect tracked as NVBug 3384810.
template <typename T, typename U>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
  // NVBug 3384810
  new (ptr) T(::cuda::std::forward<U>(val));
}
#else
// Overload selected (via SFINAE) for trivially copyable T: plain assignment
// into the uninitialized slot is valid and lets the compiler emit a raw store.
template <typename T,
          typename U,
          typename ::cuda::std::enable_if<
            ::cuda::std::is_trivially_copyable<T>::value,
            int
          >::type = 0>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
  *ptr = ::cuda::std::forward<U>(val);
}
// Overload selected (via SFINAE) for non-trivially-copyable T: construct the
// object in place with placement new, perfectly forwarding the source value.
template <typename T,
          typename U,
          typename ::cuda::std::enable_if<
            !::cuda::std::is_trivially_copyable<T>::value,
            int
          >::type = 0>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
  new (ptr) T(::cuda::std::forward<U>(val));
}
#endif
}
// namespace detail
mamba/evals/lm_harness_eval.py
0 → 100644
View file @
2eefe3d6
import
torch
import
transformers
from
transformers
import
AutoTokenizer
from
mamba_ssm.models.mixer_seq_simple
import
MambaLMHeadModel
from
lm_eval.api.model
import
LM
from
lm_eval.models.huggingface
import
HFLM
from
lm_eval.api.registry
import
register_model
from
lm_eval.__main__
import
cli_evaluate
@register_model("mamba")
class MambaEvalWrapper(HFLM):
    """lm-eval harness adapter that serves a MambaLMHeadModel through the HFLM
    interface, registered under the model name ``"mamba"``."""

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None,
                 device="cuda", dtype=torch.float16):
        """Load the pretrained Mamba checkpoint and its tokenizer.

        Note: this deliberately calls ``LM.__init__`` (the harness base class),
        skipping ``HFLM.__init__`` — the HF model-loading machinery is replaced
        by ``MambaLMHeadModel.from_pretrained`` below.
        """
        LM.__init__(self)
        self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
        # Mamba checkpoints reuse the GPT-NeoX tokenizer, which has no pad
        # token; reuse EOS for padding.
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size
        # Harness may pass batch_size as a string/int or None; default to 64.
        self._batch_size = int(batch_size) if batch_size is not None else 64
        self._max_length = max_length
        self._device = torch.device(device)

    @property
    def batch_size(self):
        # Exposed read-only, as HFLM expects a ``batch_size`` property.
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # Generation-based tasks are not supported by this wrapper (only
        # loglikelihood-style evaluation through HFLM's default paths).
        raise NotImplementedError()
if __name__ == "__main__":
    # Delegate to lm-eval's standard command-line entry point; the decorator
    # above has already registered the "mamba" model type with the harness.
    cli_evaluate()
mamba/mamba_ssm/__init__.py
0 → 100644
View file @
2eefe3d6
__version__
=
"2.2.2"
from
mamba_ssm.ops.selective_scan_interface
import
selective_scan_fn
,
mamba_inner_fn
from
mamba_ssm.modules.mamba_simple
import
Mamba
from
mamba_ssm.modules.mamba2
import
Mamba2
from
mamba_ssm.models.mixer_seq_simple
import
MambaLMHeadModel
mamba/mamba_ssm/distributed/__init__.py
0 → 100644
View file @
2eefe3d6
mamba/mamba_ssm/distributed/distributed_utils.py
0 → 100644
View file @
2eefe3d6
from
typing
import
Optional
import
torch
from
torch
import
Tensor
from
torch.distributed
import
ProcessGroup
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch: on old versions we alias the public names to the private
# implementations so the rest of this module can use the new API uniformly.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-gather ``input_`` along dim 0 across ``process_group``.

    Returns ``(output, handle)`` where ``output`` has ``world_size`` times the
    first dimension of ``input_`` and ``handle`` is the async work handle when
    ``async_op=True`` (otherwise whatever the collective returns).
    """
    n_ranks = torch.distributed.get_world_size(process_group)
    gathered = torch.empty(
        n_ranks * input_.shape[0],
        *input_.shape[1:],
        dtype=input_.dtype,
        device=input_.device,
    )
    handle = torch.distributed.all_gather_into_tensor(
        gathered, input_.contiguous(), group=process_group, async_op=async_op
    )
    return gathered, handle
# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """Reduce-scatter ``input_`` along dim 0 across ``process_group``.

    Requires the first dimension of ``input_`` to be divisible by the group's
    world size; each rank receives one reduced shard. Returns
    ``(output, handle)`` as in :func:`all_gather_raw`.
    """
    n_ranks = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % n_ranks == 0
    shard = torch.empty(
        input_.shape[0] // n_ranks,
        *input_.shape[1:],
        dtype=input_.dtype,
        device=input_.device,
    )
    handle = torch.distributed.reduce_scatter_tensor(
        shard, input_.contiguous(), group=process_group, async_op=async_op
    )
    return shard, handle
# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-reduce ``input_`` in place (on a contiguous copy if needed) across
    ``process_group``. Returns ``(reduced_tensor, handle)``."""
    contiguous_input = input_.contiguous()
    handle = torch.distributed.all_reduce(
        contiguous_input, group=process_group, async_op=async_op
    )
    return contiguous_input, handle
class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Save the group for backward; forward is a (synchronous) all-gather.
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # The adjoint of all-gather is reduce-scatter; None for process_group.
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Save the group for backward; forward is a (synchronous) reduce-scatter.
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # The adjoint of reduce-scatter is all-gather; None for process_group.
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input across the process group (identity in backward)."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Gradient of a sum w.r.t. each addend is the incoming gradient itself.
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    """Broadcast every parameter tagged ``_shared_params=True`` from (group)
    rank 0 so shared weights stay identical across the process group."""
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    # (Local name has a typo — "pamams" — kept as-is; code unchanged here.)
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    """All-reduce the gradients of parameters tagged ``_sequence_parallel=True``
    across ``process_group``, coalescing them into one flat buffer so a single
    collective call is issued."""
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            # Flatten -> one all_reduce -> unflatten and copy back into the
            # original grad tensors.
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes: ``dim`` is first
    expressed in units of ``multiple_of``; each rank gets the base share, and
    the first ``remainder`` ranks get one extra unit.
    """
    n_units = dim // multiple_of
    base_share, remainder = divmod(n_units, world_size)
    units_for_this_rank = base_share + (1 if local_rank < remainder else 0)
    return units_for_this_rank * multiple_of
mamba/mamba_ssm/distributed/tensor_parallel.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch
import
Tensor
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
torch.distributed
import
ProcessGroup
from
einops
import
rearrange
from
mamba_ssm.distributed.distributed_utils
import
(
all_gather_raw
,
all_reduce
,
all_reduce_raw
,
reduce_scatter
,
reduce_scatter_raw
,
)
class ParallelLinearFunc(torch.autograd.Function):
    """Autograd function for a linear layer under tensor parallelism.

    Forward: optionally all-gather x across the process group (sequence
    parallelism), then F.linear. Backward: the matching reduce-scatter /
    all-reduce of the input gradient, issued async so communication overlaps
    with the weight-gradient computation.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        # Cast inputs to the autocast dtype by hand before the matmul.
        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            # The all_gather was async; wait for total_x before the matmul.
            handle_x.wait()
        # NOTE(review): batch_shape / n / batch_dim are computed but never used
        # in this forward — they look like leftovers from a reshaping variant.
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        # Save the un-gathered x; backward re-gathers it instead of keeping the
        # gathered copy alive on every rank.
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Returns grads for (x, weight, bias) plus None for the two non-tensor args.
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                # Kick off the re-gather of x early; it is awaited only right
                # before the weight-grad einsum below.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                # Async reduce of grad_input, overlapped with the weight/bias grads.
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None
def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    """Functional entry point for :class:`ParallelLinearFunc`.

    Behaves like ``F.linear`` when ``process_group`` is None; otherwise runs
    the tensor-parallel (optionally sequence-parallel) variant.
    """
    return ParallelLinearFunc.apply(
        x,
        weight,
        bias,
        process_group,
        sequence_parallel,
    )
class ColumnParallelLinear(nn.Linear):
    """Linear layer whose output features are sharded column-wise across a process group.

    Each rank holds a contiguous slice of ``out_features`` counted in units of
    ``multiple_of``; when the split is uneven, the first ranks take one extra unit.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        # Split out_features (in units of multiple_of) as evenly as possible;
        # the first `leftover` ranks each receive one extra unit.
        units, leftover = divmod(out_features // multiple_of, world_size)
        local_units = units + int(torch.distributed.get_rank(process_group) < leftover)
        super().__init__(
            in_features, local_units * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with
        # sequence parallelism: x is all-gathered before the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
class RowParallelLinear(nn.Linear):
    """Linear layer whose input features are sharded row-wise across a process group.

    Each rank holds a contiguous slice of ``in_features`` counted in units of
    ``multiple_of``. Only rank 0 keeps the bias, so it is added exactly once
    after the cross-rank reduction in forward.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        # Split in_features (in units of multiple_of) as evenly as possible;
        # the first `leftover` ranks each receive one extra unit.
        units, leftover = divmod(in_features // multiple_of, world_size)
        local_units = units + int(rank < leftover)
        # Only rank 0 will have bias
        super().__init__(
            local_units * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        partial_out = parallel_linear_func(x, self.weight, self.bias)
        reducer = reduce_scatter if self.sequence_parallel else all_reduce
        return reducer(partial_out, self.process_group)
class VocabParallelEmbedding(nn.Embedding):
    """Embedding whose vocabulary (row) dimension is sharded across a process group.

    Each rank stores ``num_embeddings // world_size`` rows. Tokens outside the
    local shard produce zero vectors; summing shards across ranks (done by the
    caller) recovers the full embedding lookup.
    """

    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is None:
            world_size = 1
        else:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        rank = torch.distributed.get_rank(self.process_group)
        shard_size = self.num_embeddings
        start_index, end_index = rank * shard_size, (rank + 1) * shard_size
        # Mask of token ids that do NOT belong to this rank's shard.
        out_of_shard = (input < start_index) | (input >= end_index)
        # Shift into local coordinates; masked positions look up row 0 and are
        # zeroed afterwards.
        local_ids = input - start_index
        local_ids[out_of_shard] = 0
        embeddings = super().forward(local_ids)
        embeddings[out_of_shard] = 0.0
        return embeddings
class ColumnParallelEmbedding(nn.Embedding):
    """Embedding whose embedding (column) dimension is sharded across a process group.

    Each rank stores ``embedding_dim // world_size`` columns of every row.
    """

    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is None:
            world_size = 1
        else:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
class ParallelEmbeddings(nn.Module):
    """Word + (optional) learned position embeddings under tensor parallelism.

    Word embeddings are vocab-sharded, position embeddings are dim-sharded;
    the final reduce (reduce_scatter or all_reduce) combines the shards.
    NOTE(review): forward calls get_world_size on self.process_group, so this
    module presumably requires an initialized process group — confirm.
    """

    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        # Vocabulary rows sharded across ranks.
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            # Embedding dim sharded across ranks.
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                # Each rank holds a partition_dim-wide column slice of the
                # position embeddings; add it into the matching slice of the
                # (full-width) word embeddings before the reduce below.
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[..., rank * partition_dim : (rank + 1) * partition_dim] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
mamba/mamba_ssm/models/__init__.py
0 → 100644
View file @
2eefe3d6
mamba/mamba_ssm/models/config_mamba.py
0 → 100644
View file @
2eefe3d6
from
dataclasses
import
dataclass
,
field
@dataclass
class MambaConfig:
    # Hidden size of the residual stream.
    d_model: int = 2560
    # Hidden size of the gated MLP in each block; 0 means no MLP.
    d_intermediate: int = 0
    # Number of stacked blocks.
    n_layer: int = 64
    # Tokenizer vocabulary size (padded up to pad_vocab_size_multiple when the model is built).
    vocab_size: int = 50277
    # Extra kwargs for the SSM mixer; the "layer" key selects "Mamba1" or "Mamba2".
    ssm_cfg: dict = field(default_factory=dict)
    # Layer indices that use multi-head attention instead of an SSM mixer.
    attn_layer_idx: list = field(default_factory=list)
    # Extra kwargs for the attention mixer on those layers.
    attn_cfg: dict = field(default_factory=dict)
    # Use RMSNorm instead of LayerNorm.
    rms_norm: bool = True
    # Keep the residual stream in fp32.
    residual_in_fp32: bool = True
    # Use the fused add + norm (Triton) path.
    fused_add_norm: bool = True
    # Round vocab_size up to a multiple of this value.
    pad_vocab_size_multiple: int = 8
    # Tie lm_head weight to the input embedding weight.
    tie_embeddings: bool = True
mamba/mamba_ssm/models/mixer_seq_simple.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2023, Albert Gu, Tri Dao.
import
math
from
functools
import
partial
import
json
import
os
import
copy
from
collections
import
namedtuple
import
torch
import
torch.nn
as
nn
from
mamba_ssm.models.config_mamba
import
MambaConfig
from
mamba_ssm.modules.mamba_simple
import
Mamba
from
mamba_ssm.modules.mamba2
import
Mamba2
from
mamba_ssm.modules.mha
import
MHA
from
mamba_ssm.modules.mlp
import
GatedMLP
from
mamba_ssm.modules.block
import
Block
from
mamba_ssm.utils.generation
import
GenerationMixin
from
mamba_ssm.utils.hf
import
load_config_hf
,
load_state_dict_hf
try
:
from
mamba_ssm.ops.triton.layer_norm
import
RMSNorm
,
layer_norm_fn
,
rms_norm_fn
except
ImportError
:
RMSNorm
,
layer_norm_fn
,
rms_norm_fn
=
None
,
None
,
None
def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one residual Block: an SSM mixer (Mamba1/Mamba2) or MHA, plus an
    optional gated MLP when d_intermediate > 0.

    Layers whose index appears in attn_layer_idx use MHA; all others use the
    SSM selected by ssm_cfg["layer"] (default "Mamba1").
    """
    ssm_cfg = {} if ssm_cfg is None else ssm_cfg
    attn_layer_idx = [] if attn_layer_idx is None else attn_layer_idx
    attn_cfg = {} if attn_cfg is None else attn_cfg
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx in attn_layer_idx:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    else:
        # Work on a copy, since the "layer" key is popped out of the config.
        cfg = copy.deepcopy(ssm_cfg)
        ssm_layer = cfg.pop("layer", "Mamba1")
        if ssm_layer not in ("Mamba1", "Mamba2"):
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        ssm_cls = Mamba2 if ssm_layer == "Mamba2" else Mamba
        mixer_cls = partial(ssm_cls, layer_idx=layer_idx, **cfg, **factory_kwargs)
    norm_cls = partial(RMSNorm if rms_norm else nn.LayerNorm, eps=norm_epsilon, **factory_kwargs)
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def
_init_weights
(
module
,
n_layer
,
initializer_range
=
0.02
,
# Now only used for embedding layer.
rescale_prenorm_residual
=
True
,
n_residuals_per_layer
=
1
,
# Change to 2 if we have MLP
):
if
isinstance
(
module
,
nn
.
Linear
):
if
module
.
bias
is
not
None
:
if
not
getattr
(
module
.
bias
,
"_no_reinit"
,
False
):
nn
.
init
.
zeros_
(
module
.
bias
)
elif
isinstance
(
module
,
nn
.
Embedding
):
nn
.
init
.
normal_
(
module
.
weight
,
std
=
initializer_range
)
if
rescale_prenorm_residual
:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for
name
,
p
in
module
.
named_parameters
():
if
name
in
[
"out_proj.weight"
,
"fc2.weight"
]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn
.
init
.
kaiming_uniform_
(
p
,
a
=
math
.
sqrt
(
5
))
with
torch
.
no_grad
():
p
/=
math
.
sqrt
(
n_residuals_per_layer
*
n_layer
)
class MixerModel(nn.Module):
    """Backbone: token embedding -> n_layer mixer Blocks -> final norm.

    Each Block returns (hidden_states, residual); the residual stream is only
    folded back in at the final norm (fused via layer_norm_fn when enabled).
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            # The fused path requires the Triton kernels imported at module top.
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # One cache entry per layer, keyed by layer index.
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        # residual is None for the first layer; each Block threads it through.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm)
            )
        return hidden_states
class MambaLMHeadModel(nn.Module, GenerationMixin):
    """MixerModel backbone with a (optionally weight-tied) LM head."""

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        # Round vocab_size up to a multiple of pad_vocab_size_multiple.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        # Share the lm_head weight with the input embedding when configured.
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        # NOTE(review): the namedtuple type is re-created on every call; harmless
        # but could be hoisted to module level.
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        # Build the model from the HF-hosted config, then load the weights.
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f, indent=4)
mamba/mamba_ssm/modules/__init__.py
0 → 100644
View file @
2eefe3d6
mamba/mamba_ssm/modules/block.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.
from
typing
import
Optional
import
torch
from
torch
import
nn
,
Tensor
from
mamba_ssm.ops.triton.layer_norm
import
RMSNorm
,
layer_norm_fn
class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        # A second norm + MLP sub-block only exists when mlp_cls is not Identity.
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            # Unfused path: add, norm (in the norm weight's dtype), and keep the
            # residual stream (optionally promoted to fp32).
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # Fused Triton add + norm; prenorm=True makes it return the new
            # residual alongside the normalized activations.
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        if self.mlp is not None:
            # Second Add -> LN -> MLP sub-block, mirroring the structure above.
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Delegate cache allocation to the mixer (the only stateful sub-module).
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
mamba/mamba_ssm/modules/mamba2.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
try
:
from
causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
except
ImportError
:
causal_conv1d_fn
,
causal_conv1d_update
=
None
,
None
try
:
from
causal_conv1d.causal_conv1d_varlen
import
causal_conv1d_varlen_states
except
ImportError
:
causal_conv1d_varlen_states
=
None
try
:
from
mamba_ssm.ops.triton.selective_state_update
import
selective_state_update
except
ImportError
:
selective_state_update
=
None
from
mamba_ssm.ops.triton.layernorm_gated
import
RMSNorm
as
RMSNormGated
from
mamba_ssm.distributed.tensor_parallel
import
ColumnParallelLinear
,
RowParallelLinear
from
mamba_ssm.distributed.distributed_utils
import
all_reduce
,
reduce_scatter
from
mamba_ssm.ops.triton.ssd_combined
import
mamba_chunk_scan_combined
from
mamba_ssm.ops.triton.ssd_combined
import
mamba_split_conv1d_scan_combined
from
huggingface_hub
import
PyTorchModelHubMixin
class
Mamba2
(
nn
.
Module
,
PyTorchModelHubMixin
):
    def __init__(
        self,
        d_model,
        d_state=128,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
        ngroups=1,
        A_init_range=(1, 16),
        D_has_hdim=False,
        rmsnorm=True,
        norm_before_gate=False,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        process_group=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """Mamba2 mixer. All *per-rank* sizes (d_inner, d_ssm, ngroups, nheads)
        are the global sizes divided by the tensor-parallel world size."""
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.world_size = 1 if process_group is None else process_group.size()
        self.local_rank = 0 if process_group is None else process_group.rank()
        # Per-rank inner width; must divide evenly across ranks.
        self.d_inner = (self.expand * self.d_model) // self.world_size
        assert self.d_inner * self.world_size == self.expand * self.d_model
        self.headdim = headdim
        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
        assert ngroups % self.world_size == 0
        self.ngroups = ngroups // self.world_size
        assert self.d_ssm % self.headdim == 0
        self.nheads = self.d_ssm // self.headdim
        self.D_has_hdim = D_has_hdim
        self.rmsnorm = rmsnorm
        self.norm_before_gate = norm_before_gate
        self.dt_limit = dt_limit
        self.activation = "silu"
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        if self.process_group is None:
            self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
        else:
            # Column-sharded: pass the *global* output width; the layer keeps
            # only this rank's shard.
            self.in_proj = ColumnParallelLinear(
                self.d_model,
                d_in_proj * self.world_size,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # Depthwise causal conv over the concatenated [x, B, C] channels.
        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        self.act = nn.SiLU()

        # Initialize log dt bias
        # dt is sampled log-uniformly in [dt_min, dt_max], floored at dt_init_floor.
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A is stored as log(A) with A sampled uniformly in A_init_range (> 0).
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
        self.D._no_weight_decay = True

        if self.rmsnorm:
            assert RMSNormGated is not None
            # NOTE(review): group_size uses the global `ngroups`, not the
            # per-rank self.ngroups — presumably intentional so group size is
            # consistent across ranks; confirm.
            self.norm = RMSNormGated(
                self.d_ssm,
                eps=1e-5,
                norm_before_gate=self.norm_before_gate,
                group_size=self.d_ssm // ngroups,
                **factory_kwargs,
            )

        if self.process_group is None:
            self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        else:
            # Row-sharded: pass the *global* input width; the reduction across
            # ranks happens inside RowParallelLinear.forward.
            self.out_proj = RowParallelLinear(
                self.d_inner * self.world_size,
                self.d_model,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )
def
forward
(
self
,
u
,
seqlen
=
None
,
seq_idx
=
None
,
cu_seqlens
=
None
,
inference_params
=
None
):
"""
u: (batch, seqlen, hidden_dim) if seqlen=None.
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
split u during sequence parallel, we split the batch * seqlen dimension
(in case batch is small).
Returns: same shape as u
"""
seqlen_og
=
seqlen
if
seqlen
is
None
:
batch
,
seqlen
,
dim
=
u
.
shape
else
:
batch_seqlen
,
dim
=
u
.
shape
batch
=
batch_seqlen
//
seqlen
conv_state
,
ssm_state
=
None
,
None
if
inference_params
is
not
None
:
inference_batch
=
cu_seqlens
.
shape
[
0
]
-
1
if
cu_seqlens
is
not
None
else
batch
conv_state
,
ssm_state
=
self
.
_get_states_from_cache
(
inference_params
,
inference_batch
)
if
inference_params
.
seqlen_offset
>
0
:
# The states are updated inplace
out
,
_
,
_
=
self
.
step
(
u
,
conv_state
,
ssm_state
)
return
out
zxbcdt
=
self
.
in_proj
(
u
)
# (B, L, d_in_proj) or (B * L, d_in_proj)
if
seqlen_og
is
not
None
:
zxbcdt
=
rearrange
(
zxbcdt
,
"(b l) d -> b l d"
,
l
=
seqlen
)
# If the model is loaded in fp16, without the .float() here, A might be -inf
A
=
-
torch
.
exp
(
self
.
A_log
.
float
())
# (nheads) or (d_inner, d_state)
dt_limit_kwargs
=
{}
if
self
.
dt_limit
==
(
0.0
,
float
(
"inf"
))
else
dict
(
dt_limit
=
self
.
dt_limit
)
if
self
.
use_mem_eff_path
and
inference_params
is
None
:
out
=
mamba_split_conv1d_scan_combined
(
zxbcdt
,
rearrange
(
self
.
conv1d
.
weight
,
"d 1 w -> d w"
),
self
.
conv1d
.
bias
,
self
.
dt_bias
,
A
,
D
=
rearrange
(
self
.
D
,
"(h p) -> h p"
,
p
=
self
.
headdim
)
if
self
.
D_has_hdim
else
self
.
D
,
chunk_size
=
self
.
chunk_size
,
seq_idx
=
seq_idx
,
activation
=
self
.
activation
,
rmsnorm_weight
=
self
.
norm
.
weight
if
self
.
rmsnorm
else
None
,
rmsnorm_eps
=
self
.
norm
.
eps
if
self
.
rmsnorm
else
1e-6
,
outproj_weight
=
self
.
out_proj
.
weight
,
outproj_bias
=
self
.
out_proj
.
bias
,
headdim
=
None
if
self
.
D_has_hdim
else
self
.
headdim
,
ngroups
=
self
.
ngroups
,
norm_before_gate
=
self
.
norm_before_gate
,
**
dt_limit_kwargs
,
)
if
seqlen_og
is
not
None
:
out
=
rearrange
(
out
,
"b l d -> (b l) d"
)
if
self
.
process_group
is
not
None
:
reduce_fn
=
reduce_scatter
if
self
.
sequence_parallel
else
all_reduce
out
=
reduce_fn
(
out
,
self
.
process_group
)
else
:
d_mlp
=
(
zxbcdt
.
shape
[
-
1
]
-
2
*
self
.
d_ssm
-
2
*
self
.
ngroups
*
self
.
d_state
-
self
.
nheads
)
//
2
z0
,
x0
,
z
,
xBC
,
dt
=
torch
.
split
(
zxbcdt
,
[
d_mlp
,
d_mlp
,
self
.
d_ssm
,
self
.
d_ssm
+
2
*
self
.
ngroups
*
self
.
d_state
,
self
.
nheads
],
dim
=-
1
)
if
conv_state
is
not
None
:
if
cu_seqlens
is
None
:
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
xBC_t
=
rearrange
(
xBC
,
"b l d -> b d l"
)
conv_state
.
copy_
(
F
.
pad
(
xBC_t
,
(
self
.
d_conv
-
xBC_t
.
shape
[
-
1
],
0
)))
# Update state (B D W)
else
:
assert
causal_conv1d_varlen_states
is
not
None
,
"varlen inference requires causal_conv1d package"
assert
batch
==
1
,
"varlen inference only supports batch dimension 1"
conv_varlen_states
=
causal_conv1d_varlen_states
(
xBC
.
squeeze
(
0
),
cu_seqlens
,
state_len
=
conv_state
.
shape
[
-
1
]
)
conv_state
.
copy_
(
conv_varlen_states
)
assert
self
.
activation
in
[
"silu"
,
"swish"
]
if
causal_conv1d_fn
is
None
or
self
.
activation
not
in
[
"silu"
,
"swish"
]:
assert
seq_idx
is
None
,
"varlen conv1d requires the causal_conv1d package"
xBC
=
self
.
act
(
self
.
conv1d
(
xBC
.
transpose
(
1
,
2
)).
transpose
(
1
,
2
)[:,
-
(
self
.
dconv
-
1
):]
)
# (B, L, self.d_ssm + 2 * ngroups * d_state)
else
:
xBC
=
causal_conv1d_fn
(
xBC
.
transpose
(
1
,
2
),
rearrange
(
self
.
conv1d
.
weight
,
"d 1 w -> d w"
),
bias
=
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
seq_idx
=
seq_idx
,
).
transpose
(
1
,
2
)
x
,
B
,
C
=
torch
.
split
(
xBC
,
[
self
.
d_ssm
,
self
.
ngroups
*
self
.
d_state
,
self
.
ngroups
*
self
.
d_state
],
dim
=-
1
)
y
=
mamba_chunk_scan_combined
(
rearrange
(
x
,
"b l (h p) -> b l h p"
,
p
=
self
.
headdim
),
dt
,
A
,
rearrange
(
B
,
"b l (g n) -> b l g n"
,
g
=
self
.
ngroups
),
rearrange
(
C
,
"b l (g n) -> b l g n"
,
g
=
self
.
ngroups
),
chunk_size
=
self
.
chunk_size
,
D
=
rearrange
(
self
.
D
,
"(h p) -> h p"
,
p
=
self
.
headdim
)
if
self
.
D_has_hdim
else
self
.
D
,
z
=
rearrange
(
z
,
"b l (h p) -> b l h p"
,
p
=
self
.
headdim
)
if
not
self
.
rmsnorm
else
None
,
dt_bias
=
self
.
dt_bias
,
dt_softplus
=
True
,
seq_idx
=
seq_idx
,
cu_seqlens
=
cu_seqlens
,
**
dt_limit_kwargs
,
return_final_states
=
ssm_state
is
not
None
,
return_varlen_states
=
cu_seqlens
is
not
None
and
inference_params
is
not
None
,
)
if
ssm_state
is
not
None
:
y
,
last_state
,
*
rest
=
y
if
cu_seqlens
is
None
:
ssm_state
.
copy_
(
last_state
)
else
:
varlen_states
=
rest
[
0
]
ssm_state
.
copy_
(
varlen_states
)
y
=
rearrange
(
y
,
"b l h p -> b l (h p)"
)
if
self
.
rmsnorm
:
y
=
self
.
norm
(
y
,
z
)
if
d_mlp
>
0
:
y
=
torch
.
cat
([
F
.
silu
(
z0
)
*
x0
,
y
],
dim
=-
1
)
if
seqlen_og
is
not
None
:
y
=
rearrange
(
y
,
"b l d -> (b l) d"
)
out
=
self
.
out_proj
(
y
)
return
out
def step(self, hidden_states, conv_state, ssm_state):
    """Single-token decoding step for Mamba-2.

    hidden_states: (B, 1, D) — exactly one new token per sequence (asserted below).
    conv_state, ssm_state: recurrent states for this layer; updated IN PLACE.
    Returns: (out, conv_state, ssm_state) with out of shape (B, 1, D).
    """
    dtype = hidden_states.dtype
    assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
    # Project the token to the concatenated [z0, x0, z, xBC, dt] channels.
    zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
    # Width of the optional gated-MLP branch (z0/x0); 0 when in_proj carries only SSM channels.
    d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
    z0, x0, z, xBC, dt = torch.split(
        zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
    )
    # Conv step: advance the depthwise causal conv window by one position.
    if causal_conv1d_update is None:
        # Pure-PyTorch fallback: shift the window left, append the new token,
        # then evaluate the conv as a weighted sum over the window.
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
        conv_state[:, :, -1] = xBC
        xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
        if self.conv1d.bias is not None:
            xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC).to(dtype=dtype)
    else:
        # Fused kernel: updates conv_state in place and returns the activated output.
        xBC = causal_conv1d_update(
            xBC,
            conv_state,
            rearrange(self.conv1d.weight, "d 1 w -> d w"),
            self.conv1d.bias,
            self.activation,
        )
    x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
    # .float() so exp() cannot overflow to -inf under fp16 parameters.
    A = -torch.exp(self.A_log.float())  # (nheads,)
    # SSM step
    if selective_state_update is None:
        # Pure-PyTorch fallback for the selective state update.
        assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
        # Discretize A and B
        dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
        dA = torch.exp(dt * A)  # (batch, nheads)
        x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
        dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
        # Recurrence: state <- state * dA + dt*B*x, written in place.
        ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
        y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
        # D "skip" connection.
        y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
        y = rearrange(y, "b h p -> b (h p)")
        if not self.rmsnorm:
            # Gate with z here only when the gated norm below is disabled.
            y = y * self.act(z)  # (B D)
    else:
        # Fused Triton kernel path: broadcast per-head parameters to per-channel shapes.
        A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
        dt = repeat(dt, "b h -> b h p", p=self.headdim)
        dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
        D = repeat(self.D, "h -> h p", p=self.headdim)
        B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
        C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
        x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
        if not self.rmsnorm:
            z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
        y = selective_state_update(
            ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
            dt_bias=dt_bias, dt_softplus=True
        )
        y = rearrange(y, "b h p -> b (h p)")
    if self.rmsnorm:
        # Gated RMSNorm: gating by z happens inside the norm.
        y = self.norm(y, z)
    if d_mlp > 0:
        # Prepend the gated-MLP branch output when that branch exists.
        y = torch.cat([F.silu(z0) * x0, y], dim=-1)
    out = self.out_proj(y)
    return out.unsqueeze(1), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
    """Allocate zero-initialized decoding states for ``batch_size`` sequences.

    Returns (conv_state, ssm_state); ``max_seqlen`` and extra kwargs are accepted
    for interface compatibility but not used here.
    """
    target_device = self.out_proj.weight.device
    # Dtypes follow the corresponding weights unless the caller overrides them.
    conv_dtype = dtype if dtype is not None else self.conv1d.weight.dtype
    ssm_dtype = dtype if dtype is not None else self.in_proj.weight.dtype
    # Built as (B, W, D) then viewed as (B, D, W) — same memory layout as the
    # original allocation, which the decode path was written against.
    conv_state = torch.zeros(
        batch_size,
        self.d_conv,
        self.conv1d.weight.shape[0],
        device=target_device,
        dtype=conv_dtype,
    ).transpose(1, 2)
    ssm_state = torch.zeros(
        batch_size,
        self.nheads,
        self.headdim,
        self.d_state,
        device=target_device,
        dtype=ssm_dtype,
    )
    return conv_state, ssm_state
def
_get_states_from_cache
(
self
,
inference_params
,
batch_size
,
initialize_states
=
False
):
assert
self
.
layer_idx
is
not
None
if
self
.
layer_idx
not
in
inference_params
.
key_value_memory_dict
:
batch_shape
=
(
batch_size
,)
conv_state
=
torch
.
zeros
(
batch_size
,
self
.
d_conv
,
self
.
conv1d
.
weight
.
shape
[
0
],
device
=
self
.
conv1d
.
weight
.
device
,
dtype
=
self
.
conv1d
.
weight
.
dtype
,
).
transpose
(
1
,
2
)
ssm_state
=
torch
.
zeros
(
batch_size
,
self
.
nheads
,
self
.
headdim
,
self
.
d_state
,
device
=
self
.
in_proj
.
weight
.
device
,
dtype
=
self
.
in_proj
.
weight
.
dtype
,
)
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
]
=
(
conv_state
,
ssm_state
)
else
:
conv_state
,
ssm_state
=
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
]
# TODO: What if batch size changes between generation, and we reuse the same states?
if
initialize_states
:
conv_state
.
zero_
()
ssm_state
.
zero_
()
return
conv_state
,
ssm_state
mamba/mamba_ssm/modules/mamba2_simple.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
try
:
from
causal_conv1d
import
causal_conv1d_fn
except
ImportError
:
causal_conv1d_fn
=
None
try
:
from
mamba_ssm.ops.triton.layernorm_gated
import
RMSNorm
as
RMSNormGated
,
LayerNorm
except
ImportError
:
RMSNormGated
,
LayerNorm
=
None
,
None
from
mamba_ssm.ops.triton.ssd_combined
import
mamba_chunk_scan_combined
from
mamba_ssm.ops.triton.ssd_combined
import
mamba_split_conv1d_scan_combined
class Mamba2Simple(nn.Module):
    """Simplified Mamba-2 block: input projection -> depthwise causal conv ->
    selective state-space scan (SSD) -> gated RMSNorm -> output projection.

    Unlike the full Mamba2 module, this variant has no incremental-decoding
    (step / state-cache) support; it only runs the full-sequence forward.
    """

    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=128,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        # Inner width is an expansion of the model width.
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        # The conv runs over x, B and C together (z and dt bypass it).
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True

        if self.learnable_init_states:
            # Learned initial SSM state, broadcast over the batch in forward().
            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias
        # Sample dt log-uniformly in [dt_min, dt_max], then store softplus^{-1}(dt)
        # so that F.softplus(dt_bias) lands back in that range.
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        seq_idx: optional sequence-index tensor forwarded to the fused kernels.
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        # NOTE(review): no .float() on A_log here, unlike mamba2.py which comments
        # that without it A can become -inf when the model is loaded in fp16 — verify.
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
        # Only pass dt_limit to the kernels when it is non-default.
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
        else:
            # Unfused reference path: split channels, conv, then chunked scan.
            z, xBC, dt = torch.split(
                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                # Trim the causal padding the Conv1d adds on the right.
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
mamba/mamba_ssm/modules/mamba_simple.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2023, Tri Dao, Albert Gu.
import
math
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch
import
Tensor
from
einops
import
rearrange
,
repeat
from
mamba_ssm.ops.selective_scan_interface
import
selective_scan_fn
,
mamba_inner_fn
try
:
from
causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
except
ImportError
:
causal_conv1d_fn
,
causal_conv1d_update
=
None
,
None
try
:
from
mamba_ssm.ops.triton.selective_state_update
import
selective_state_update
except
ImportError
:
selective_state_update
=
None
try
:
from
mamba_ssm.ops.triton.layer_norm
import
RMSNorm
,
layer_norm_fn
,
rms_norm_fn
except
ImportError
:
RMSNorm
,
layer_norm_fn
,
rms_norm_fn
=
None
,
None
,
None
class Mamba(nn.Module):
    """Mamba-1 block (selective scan / S6): input projection -> depthwise causal
    conv -> input-dependent SSM scan -> gated output projection.

    Supports a fully fused fast path for training and an explicit conv/SSM state
    cache for incremental decoding via ``step``.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        # Low-rank dimension for the dt projection; "auto" scales with d_model.
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        # Produces x and z (gate) stacked: 2 * d_inner channels.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        # x_proj emits [dt (low-rank), B, C] per position.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        inference_params: optional decoding context; when it carries a nonzero
            seqlen_offset, a single-token ``step`` is taken instead of a full scan.
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        # .float() so exp() stays finite under fp16/bf16 parameters.
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:
            # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                # Plain Conv1d pads on both sides of time; trim the tail back to seqlen.
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                # With return_last_state=True the scan returns (y, last_state).
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Single-token decoding step.

        hidden_states: (B, 1, D). conv_state/ssm_state are updated IN PLACE.
        Returns: (out, conv_state, ssm_state) with out of shape (B, 1, D).
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            # Pure-PyTorch fallback: shift the window, append, weighted-sum.
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Pure-PyTorch fallback: discretize and take one recurrence step.
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            # Fused kernel adds dt_bias, applies softplus and the gate internally.
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed (conv_state, ssm_state) for decoding; dtype defaults
        to the corresponding weights' dtype."""
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch (or lazily create) this layer's cached states; optionally zero
        them in place when ``initialize_states`` is True."""
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
        if initialize_states:
            conv_state.zero_()
            ssm_state.zero_()
        return conv_state, ssm_state
mamba/mamba_ssm/modules/mha.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
try
:
from
flash_attn
import
flash_attn_with_kvcache
except
ImportError
:
flash_attn_with_kvcache
=
None
try
:
from
flash_attn.layers.rotary
import
RotaryEmbedding
except
ImportError
:
RotaryEmbedding
=
None
try
:
from
causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
except
ImportError
:
causal_conv1d_fn
,
causal_conv1d_update
=
None
,
None
def
_update_kv_cache
(
kv
,
inference_params
,
layer_idx
):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
# Pre-allocate memory for key-values for inference.
num_heads
,
head_dim
=
kv
.
shape
[
-
2
:]
assert
layer_idx
in
inference_params
.
key_value_memory_dict
kv_cache
,
_
=
inference_params
.
key_value_memory_dict
[
layer_idx
]
# Adjust key and value for inference
batch_start
=
inference_params
.
batch_size_offset
batch_end
=
batch_start
+
kv
.
shape
[
0
]
sequence_start
=
inference_params
.
seqlen_offset
sequence_end
=
sequence_start
+
kv
.
shape
[
1
]
assert
batch_end
<=
kv_cache
.
shape
[
0
]
assert
sequence_end
<=
kv_cache
.
shape
[
1
]
assert
kv_cache
is
not
None
kv_cache
[
batch_start
:
batch_end
,
sequence_start
:
sequence_end
,
...]
=
kv
return
kv_cache
[
batch_start
:
batch_end
,
:
sequence_end
,
...]
class
MHA
(
nn
.
Module
):
"""Multi-head self-attention and cross-attention"""
def __init__(
    self,
    embed_dim,
    num_heads,
    num_heads_kv=None,
    head_dim=None,  # If None, use embed_dim // num_heads
    mlp_dim=0,
    qkv_proj_bias=True,
    out_proj_bias=True,
    softmax_scale=None,
    causal=False,
    layer_idx=None,
    d_conv=0,
    rotary_emb_dim=0,
    rotary_emb_base=10000.0,
    rotary_emb_interleaved=False,
    device=None,
    dtype=None,
) -> None:
    """
    num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
    return_residual: whether to return the input x along with the output. This is for
        performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.

    NOTE(review): ``return_residual`` is not a parameter of this class — the paragraph
    above looks carried over from another implementation; verify and remove if so.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.embed_dim = embed_dim
    self.layer_idx = layer_idx
    self.d_conv = d_conv
    self.rotary_emb_dim = rotary_emb_dim
    self.softmax_scale = softmax_scale
    self.causal = causal

    self.num_heads = num_heads
    # Default to standard MHA; fewer KV heads gives MQA/GQA.
    self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
    assert (
        self.num_heads % self.num_heads_kv == 0
    ), "num_heads must be divisible by num_heads_kv"
    if head_dim is None:
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
    # Round the optional parallel-MLP width up to a multiple of 256
    # (presumably for kernel alignment — verify).
    self.mlp_dim = math.ceil(mlp_dim / 256) * 256
    # Q plus K and V (at num_heads_kv heads each) projected jointly.
    qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
    out_dim = self.head_dim * self.num_heads

    if self.rotary_emb_dim > 0:
        assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
        self.rotary_emb = RotaryEmbedding(
            self.rotary_emb_dim,
            base=rotary_emb_base,
            interleaved=rotary_emb_interleaved,
            device=device,
        )

    # Single projection producing QKV and the optional MLP branch together.
    self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
    if self.d_conv > 0:
        # Depthwise causal short conv over the QKV channels.
        self.conv1d = nn.Conv1d(
            qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1,
            groups=qkv_dim, **factory_kwargs
        )
    # mlp_dim // 2: the MLP branch is gated (up * silu(gate)), halving its output width.
    self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
    """Allocate this layer's decoding buffers.

    Returns (kv_cache, conv_state): a (B, max_seqlen, 2, nheads_kv, head_dim)
    KV cache, and — only when the QKV short conv is enabled — a zeroed
    (B, channels, d_conv) conv state (otherwise None).
    """
    out_weight = self.out_proj.weight
    cache_dtype = out_weight.dtype if dtype is None else dtype
    cache_device = out_weight.device
    conv_state = (
        torch.zeros(
            batch_size,
            self.conv1d.weight.shape[0],
            self.d_conv,
            device=cache_device,
            dtype=cache_dtype,
        )
        if self.d_conv > 0
        else None
    )
    # torch.empty: contents are written lazily as tokens are cached.
    kv_cache = torch.empty(
        batch_size,
        max_seqlen,
        2,
        self.num_heads_kv,
        self.head_dim,
        dtype=cache_dtype,
        device=cache_device,
    )
    return kv_cache, conv_state
def _update_kv_cache(self, kv, inference_params):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Delegate to the module-level helper; layer_idx selects this layer's cache slot.
    assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
    layer_idx = self.layer_idx
    return _update_kv_cache(kv, inference_params, layer_idx)
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
    """
    Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
    q: (batch_size, seqlen_q, nheads, head_dim)
    kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
    """
    # Only valid during incremental decoding (at least one token already cached).
    assert inference_params is not None and inference_params.seqlen_offset > 0
    if self.rotary_emb_dim > 0:
        # Make sure the rotary cos/sin tables cover the full generation length.
        self.rotary_emb._update_cos_sin_cache(
            inference_params.max_seqlen, device=q.device, dtype=q.dtype
        )
        rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
    else:
        rotary_cos, rotary_sin = None, None
    batch = q.shape[0]
    kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
    kv_cache = kv_cache[:batch]
    # Per-sample lengths when available (ragged batches), otherwise one shared offset.
    cache_seqlens = (
        inference_params.lengths_per_sample[:batch]
        if inference_params.lengths_per_sample is not None
        else inference_params.seqlen_offset
    )
    assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
    # Per the flash_attn API, this single fused kernel applies rotary, appends the
    # new k/v into kv_cache in place, and computes attention against the cache.
    context = flash_attn_with_kvcache(
        q,
        kv_cache[:, :, 0],
        kv_cache[:, :, 1],
        kv[:, :, 0],
        kv[:, :, 1],
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        cache_seqlens=cache_seqlens,
        softmax_scale=self.softmax_scale,
        causal=self.causal,
        rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
    )
    return context
def _update_kvcache_attention(self, q, kv, inference_params):
    """Write kv to inference_params, then do attention"""
    if (
        inference_params.seqlen_offset == 0
        or flash_attn_with_kvcache is None
    ):
        # Prefill (or flash_attn unavailable): update the cache explicitly,
        # then run attention with PyTorch SDPA.
        # TODO: this only uses seqlen_offset and not lengths_per_sample.
        kv = self._update_kv_cache(kv, inference_params)
        k, v = kv.unbind(dim=-3)
        # Expand KV heads to match the query heads (MQA/GQA).
        k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
        v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
        # SDPA expects (B, H, L, D); transpose in and back out.
        return F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            is_causal=self.causal, scale=self.softmax_scale
        ).transpose(1, 2)
    else:
        # Decoding with flash_attn: fused cache update + attention.
        batch = q.shape[0]
        kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
        kv_cache = kv_cache[:batch]
        # Per-sample lengths when available, otherwise one shared offset.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        return flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
        )
def forward(self, x, inference_params=None):
    """
    Arguments:
        x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
            cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
            is the is the sum of the sequence lengths in the batch.
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
    """
    # Lazily allocate this layer's kv (and conv) cache on the first generation call.
    if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
        inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
            x.shape[0], inference_params.max_seqlen, dtype=x.dtype
        )
    # Rotary offset: per-sample lengths when available, else the shared scalar offset.
    seqlen_offset = (
        0
        if inference_params is None
        else (
            inference_params.lengths_per_sample
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
    )
    rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
    qkv = self.in_proj(x)
    if self.mlp_dim > 0:
        # in_proj also produced a parallel gated-MLP branch; split it off and gate it.
        qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
        x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
        x_mlp = x_mlp_up * F.silu(x_mlp_gate)
    if self.d_conv > 0:
        # The inference code for conv1d is pretty messy, should clean it up
        if inference_params is None or inference_params.seqlen_offset == 0:
            # Prefill (or no cache): run the causal conv over the whole sequence.
            if causal_conv1d_fn is None:
                # nn.Conv1d pads both sides; drop the trailing d_conv-1 outputs to keep causality.
                qkv = rearrange(
                    self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., : -(self.d_conv - 1)],
                    "b d s -> b s d",
                ).contiguous()
            else:
                qkv = causal_conv1d_fn(
                    qkv.transpose(1, 2),
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                ).transpose(1, 2)
            if inference_params is not None:
                # Seed the conv state with the last d_conv steps for subsequent decode calls.
                _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                qkv_t = rearrange(qkv, "b l d -> b d l")
                conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0)))  # Update state (B D W)
        else:
            # Decode: single-token update against the cached conv window.
            _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
            assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
            qkv = qkv.squeeze(1)
            # Conv step
            if causal_conv1d_update is None:
                # Manual fallback: shift the window left, append the new token,
                # then take the dot product with the conv weights.
                conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
                conv_state[:, :, -1] = qkv
                qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
                if self.conv1d.bias is not None:
                    qkv = qkv + self.conv1d.bias
            else:
                qkv = causal_conv1d_update(
                    qkv,
                    conv_state,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                )
            qkv = qkv.unsqueeze(1)
    # Split packed projection into Q and interleaved (K, V); reshape to head layout.
    q, kv = qkv.split(
        [self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1
    )
    q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
    kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
    if (
        inference_params is None
        or inference_params.seqlen_offset == 0
        or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
    ):
        # Slow path: apply rotary here, then either plain SDPA (training / no cache)
        # or the cache-updating attention. NOTE(review): the fused kernel below
        # apparently requires rotary_emb_dim to be a positive multiple of 16.
        if self.rotary_emb_dim > 0:
            q, kv = self.rotary_emb(
                q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
            )
        if inference_params is None:
            k, v = kv.unbind(dim=-3)
            # GQA: replicate kv heads to match the number of query heads.
            k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
            v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
            context = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=self.causal,
                scale=self.softmax_scale,
            ).transpose(1, 2)
        else:
            context = self._update_kvcache_attention(q, kv, inference_params)
    else:
        # Fused decode path: rotary + cache update + attention in one kernel.
        context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
    context = rearrange(context, "... h d -> ... (h d)")
    if self.mlp_dim > 0:
        # Concatenate the parallel MLP branch before the output projection.
        context = torch.cat([context, x_mlp], dim=-1)
    out = self.out_proj(context)
    return out
mamba/mamba_ssm/modules/mlp.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.
from
torch
import
nn
from
torch.nn
import
functional
as
F
class GatedMLP(nn.Module):
    """Gated (SwiGLU-style) MLP.

    fc1 projects the input to twice the hidden width; one half is gated by
    `activation` applied to the other half, and fc2 projects back down.
    The hidden width is rounded up to a multiple of `multiple_of`.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            # Default hidden width: 8/3 of the input (LLaMA-style sizing).
            hidden_features = int(8 * in_features / 3)
        # Round up to the nearest multiple_of for hardware-friendly shapes.
        hidden_features = -(-hidden_features // multiple_of) * multiple_of
        # fc1 emits [value | gate] halves in a single matmul.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        up, gate = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(up * self.activation(gate))
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment