ModelZoo / MambaVision_pytorch / Commits / 2eefe3d6

Commit 2eefe3d6, authored Sep 29, 2024 by luopl: "add mamba" (parent b7535e7c)
Pipeline #1735 failed with stages in 0 seconds
Changes: 65, Pipelines: 1

Showing 20 changed files with 2606 additions and 0 deletions (+2606, -0)
Changed files (first 20 shown):

mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu      +10   -0
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu      +10   -0
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh   +376  -0
mamba/csrc/selective_scan/static_switch.h                 +25   -0
mamba/csrc/selective_scan/uninitialized_copy.cuh          +77   -0
mamba/evals/lm_harness_eval.py                            +39   -0
mamba/mamba_ssm/__init__.py                               +6    -0
mamba/mamba_ssm/distributed/__init__.py                   +0    -0
mamba/mamba_ssm/distributed/distributed_utils.py          +144  -0
mamba/mamba_ssm/distributed/tensor_parallel.py            +296  -0
mamba/mamba_ssm/models/__init__.py                        +0    -0
mamba/mamba_ssm/models/config_mamba.py                    +18   -0
mamba/mamba_ssm/models/mixer_seq_simple.py                +309  -0
mamba/mamba_ssm/modules/__init__.py                       +0    -0
mamba/mamba_ssm/modules/block.py                          +91   -0
mamba/mamba_ssm/modules/mamba2.py                         +383  -0
mamba/mamba_ssm/modules/mamba2_simple.py                  +200  -0
mamba/mamba_ssm/modules/mamba_simple.py                   +294  -0
mamba/mamba_ssm/modules/mha.py                            +294  -0
mamba/mamba_ssm/modules/mlp.py                            +34   -0
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh (new file, mode 100644)

(Diff collapsed in the web view; 376 additions, 0 deletions.)
mamba/csrc/selective_scan/static_switch.h (new file, mode 100644)
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
mamba/csrc/selective_scan/uninitialized_copy.cuh (new file, mode 100644)
/******************************************************************************
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

#ifndef USE_ROCM
    #include <cub/config.cuh>
    #include <cuda/std/type_traits>
#else
    #include <hipcub/hipcub.hpp>
    // Map ::cuda::std to the standard std namespace
    namespace cuda {
        namespace std = ::std;
    }
#endif

namespace detail
{

#if defined(_NVHPC_CUDA)
template <typename T, typename U>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
  // NVBug 3384810
  new (ptr) T(::cuda::std::forward<U>(val));
}
#else
template <typename T,
          typename U,
          typename ::cuda::std::enable_if<
            ::cuda::std::is_trivially_copyable<T>::value,
            int
          >::type = 0>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
  *ptr = ::cuda::std::forward<U>(val);
}

template <typename T,
          typename U,
          typename ::cuda::std::enable_if<
            !::cuda::std::is_trivially_copyable<T>::value,
            int
          >::type = 0>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
  new (ptr) T(::cuda::std::forward<U>(val));
}
#endif

}  // namespace detail
mamba/evals/lm_harness_eval.py (new file, mode 100644)

import torch

import transformers
from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate


@register_model("mamba")
class MambaEvalWrapper(HFLM):

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None,
                 device="cuda", dtype=torch.float16):
        LM.__init__(self)
        self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size
        self._batch_size = int(batch_size) if batch_size is not None else 64
        self._max_length = max_length
        self._device = torch.device(device)

    @property
    def batch_size(self):
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        raise NotImplementedError()


if __name__ == "__main__":
    cli_evaluate()
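Not part of the diff: a minimal sketch of driving the wrapper above programmatically rather than through the CLI. It assumes the lm_eval package (0.4.x API) and a CUDA machine with the compiled mamba_ssm kernels; the "lm_harness_eval" import path is hypothetical and only works when running from the evals directory.

# Sketch only: evaluate the registered "mamba" wrapper on one task.
import lm_eval
from lm_harness_eval import MambaEvalWrapper  # hypothetical: run from mamba/evals/

lm = MambaEvalWrapper(pretrained="state-spaces/mamba-130m", batch_size=8)
results = lm_eval.simple_evaluate(model=lm, tasks=["lambada_openai"])
print(results["results"])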
mamba/mamba_ssm/__init__.py (new file, mode 100644)

__version__ = "2.2.2"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
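Not part of the diff: a minimal usage sketch of the public API re-exported here, assuming a CUDA device and that the package's compiled selective-scan / conv1d kernels are available.

# Sketch only: a single Mamba block applied to a random sequence.
import torch
from mamba_ssm import Mamba

model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(2, 64, 16, device="cuda")  # (batch, seqlen, d_model)
y = model(x)                               # output has the same shape as the input
assert y.shape == x.shape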
mamba/mamba_ssm/distributed/__init__.py (new empty file, mode 100644)
mamba/mamba_ssm/distributed/distributed_utils.py (new file, mode 100644)

from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
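Not part of the diff: a small, self-contained sketch of the uneven-split arithmetic that get_dim_for_local_rank implements (the first `mod` ranks receive one extra multiple). The helper name `shard_sizes` is illustrative only.

# Sketch only: reproduce get_dim_for_local_rank for every rank at once.
def shard_sizes(dim, world_size, multiple_of=1):
    multiple = dim // multiple_of
    div, mod = divmod(multiple, world_size)
    # Same rule as get_dim_for_local_rank: rank < mod gets div + 1 multiples.
    return [(div + int(rank < mod)) * multiple_of for rank in range(world_size)]

print(shard_sizes(160, 4, multiple_of=16))  # [48, 48, 32, 32]; shards always sum to 160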
mamba/mamba_ssm/distributed/tensor_parallel.py (new file, mode 100644)

# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from einops import rearrange

from mamba_ssm.distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[..., rank * partition_dim : (rank + 1) * partition_dim] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
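Not part of the diff: a minimal sanity-check sketch. When process_group is None, ParallelLinearFunc issues no collectives and reduces to an ordinary linear layer, so it can be checked on CPU. This assumes the mamba_ssm package (including its compiled extensions) and einops are importable.

# Sketch only: parallel_linear_func without a process group equals F.linear.
import torch
import torch.nn.functional as F
from mamba_ssm.distributed.tensor_parallel import parallel_linear_func

x = torch.randn(4, 32)
weight = torch.randn(64, 32, requires_grad=True)
bias = torch.randn(64, requires_grad=True)
out = parallel_linear_func(x, weight, bias)  # no process group, so no all_gather / reduce_scatter
assert torch.allclose(out, F.linear(x, weight, bias))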
mamba/mamba_ssm/models/__init__.py (new empty file, mode 100644)
mamba/mamba_ssm/models/config_mamba.py (new file, mode 100644)

from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    d_intermediate: int = 0
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
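Not part of the diff: a short sketch of filling in the config and of the pad_vocab_size_multiple rounding that MambaLMHeadModel (below) applies before building the embedding and LM head. Importing the config pulls in the mamba_ssm package, so its compiled extensions must be installed.

# Sketch only: default vocab 50277 is padded up to the next multiple of 8.
from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig(d_model=768, n_layer=24)
vocab = cfg.vocab_size
if vocab % cfg.pad_vocab_size_multiple != 0:
    vocab += cfg.pad_vocab_size_multiple - (vocab % cfg.pad_vocab_size_multiple)
print(vocab)  # 50280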
mamba/mamba_ssm/models/mixer_seq_simple.py (new file, mode 100644)

# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.modules.block import Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm)
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f, indent=4)
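Not part of the diff: a minimal usage sketch of the language-model wrapper above. It assumes a CUDA device, the compiled kernels, and access to the "state-spaces/mamba-130m" checkpoint on the Hugging Face hub.

# Sketch only: load a pretrained checkpoint and get next-token logits.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids).logits  # CausalLMOutput namedtuple; shape (1, 32, padded vocab size)
print(logits.shape)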
mamba/mamba_ssm/modules/__init__.py (new empty file, mode 100644)
mamba/mamba_ssm/modules/block.py (new file, mode 100644)

# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional

import torch
from torch import nn, Tensor

from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
mamba/mamba_ssm/modules/mamba2.py (new file, mode 100644)

# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
except ImportError:
    causal_conv1d_varlen_states = None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated

from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from huggingface_hub import PyTorchModelHubMixin


class Mamba2(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        d_model,
        d_state=128,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
        ngroups=1,
        A_init_range=(1, 16),
        D_has_hdim=False,
        rmsnorm=True,
        norm_before_gate=False,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        process_group=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.world_size = 1 if process_group is None else process_group.size()
        self.local_rank = 0 if process_group is None else process_group.rank()
        self.d_inner = (self.expand * self.d_model) // self.world_size
        assert self.d_inner * self.world_size == self.expand * self.d_model
        self.headdim = headdim
        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
        assert ngroups % self.world_size == 0
        self.ngroups = ngroups // self.world_size
        assert self.d_ssm % self.headdim == 0
        self.nheads = self.d_ssm // self.headdim
        self.D_has_hdim = D_has_hdim
        self.rmsnorm = rmsnorm
        self.norm_before_gate = norm_before_gate
        self.dt_limit = dt_limit
        self.activation = "silu"
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        if self.process_group is None:
            self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
        else:
            self.in_proj = ColumnParallelLinear(
                self.d_model, d_in_proj * self.world_size, bias=bias,
                process_group=self.process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )

        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
        self.D._no_weight_decay = True

        if self.rmsnorm:
            assert RMSNormGated is not None
            self.norm = RMSNormGated(
                self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
                group_size=self.d_ssm // ngroups, **factory_kwargs
            )

        if self.process_group is None:
            self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        else:
            self.out_proj = RowParallelLinear(
                self.d_inner * self.world_size, self.d_model, bias=bias,
                process_group=self.process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )

    def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
        """
        u: (batch, seqlen, hidden_dim) if seqlen=None.
            If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
            split u during sequence parallel, we split the batch * seqlen dimension
            (in case batch is small).
        Returns: same shape as u
        """
        seqlen_og = seqlen
        if seqlen is None:
            batch, seqlen, dim = u.shape
        else:
            batch_seqlen, dim = u.shape
            batch = batch_seqlen // seqlen

        conv_state, ssm_state = None, None
        if inference_params is not None:
            inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
            conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(u, conv_state, ssm_state)
                return out

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)
        if seqlen_og is not None:
            zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        if self.use_mem_eff_path and inference_params is None:
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=None if self.D_has_hdim else self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=self.norm_before_gate,
                **dt_limit_kwargs,
            )
            if seqlen_og is not None:
                out = rearrange(out, "b l d -> (b l) d")
            if self.process_group is not None:
                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
                out = reduce_fn(out, self.process_group)
        else:
            d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
            z0, x0, z, xBC, dt = torch.split(
                zxbcdt,
                [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
                dim=-1
            )
            if conv_state is not None:
                if cu_seqlens is None:
                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                    xBC_t = rearrange(xBC, "b l d -> b d l")
                    conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)))  # Update state (B D W)
                else:
                    assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
                    assert batch == 1, "varlen inference only supports batch dimension 1"
                    conv_varlen_states = causal_conv1d_varlen_states(
                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
                    )
                    conv_state.copy_(conv_varlen_states)
            assert self.activation in ["silu", "swish"]
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1):]
                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
            else:
                xBC = causal_conv1d_fn(
                    xBC.transpose(1, 2),
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=seq_idx,
                ).transpose(1, 2)
            x, B, C = torch.split(
                xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
                z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                **dt_limit_kwargs,
                return_final_states=ssm_state is not None,
                return_varlen_states=cu_seqlens is not None and inference_params is not None,
            )
            if ssm_state is not None:
                y, last_state, *rest = y
                if cu_seqlens is None:
                    ssm_state.copy_(last_state)
                else:
                    varlen_states = rest[0]
                    ssm_state.copy_(varlen_states)
            y = rearrange(y, "b l h p -> b l (h p)")
            if self.rmsnorm:
                y = self.norm(y, z)
            if d_mlp > 0:
                y = torch.cat([F.silu(z0) * x0, y], dim=-1)
            if seqlen_og is not None:
                y = rearrange(y, "b l d -> (b l) d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
        z0, x0, z, xBC, dt = torch.split(
            zxbcdt,
            [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = xBC
            xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                xBC = xBC + self.conv1d.bias
            xBC = self.act(xBC).to(dtype=dtype)
        else:
            xBC = causal_conv1d_update(
                xBC,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
        A = -torch.exp(self.A_log.float())  # (nheads,)

        # SSM step
        if selective_state_update is None:
            assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
            # Discretize A and B
            dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
            dA = torch.exp(dt * A)  # (batch, nheads)
            x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
            ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
            y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
            y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
            y = rearrange(y, "b h p -> b (h p)")
            if not self.rmsnorm:
                y = y * self.act(z)  # (B D)
        else:
            A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
            dt = repeat(dt, "b h -> b h p", p=self.headdim)
            dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
            D = repeat(self.D, "h -> h p", p=self.headdim)
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
            x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            if not self.rmsnorm:
                z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
            y = selective_state_update(
                ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
                dt_bias=dt_bias, dt_softplus=True
            )
            y = rearrange(y, "b h p -> b (h p)")
        if self.rmsnorm:
            y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
        ).transpose(1, 2)
        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_conv,
                self.conv1d.weight.shape[0],
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            ).transpose(1, 2)
            ssm_state = torch.zeros(
                batch_size,
                self.nheads,
                self.headdim,
                self.d_state,
                device=self.in_proj.weight.device,
                dtype=self.in_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
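Not part of the diff: a minimal stand-alone usage sketch of the Mamba2 block above. It assumes a CUDA device and the Triton / causal_conv1d kernels; note that d_model * expand must be divisible by headdim.

# Sketch only: one Mamba2 block on a random batch.
import torch
from mamba_ssm.modules.mamba2 import Mamba2

block = Mamba2(d_model=256, d_state=128, headdim=64, expand=2).to("cuda")
x = torch.randn(2, 128, 256, device="cuda")  # (batch, seqlen, d_model); 256 * 2 / 64 = 8 heads
y = block(x)                                 # output keeps the input shape
assert y.shape == x.shape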
mamba/mamba_ssm/modules/mamba2_simple.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None

try:
    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
    RMSNormGated, LayerNorm = None, None

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined


class Mamba2Simple(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=128,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True

        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
        else:
            z, xBC, dt = torch.split(
                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
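A quick usage sketch for Mamba2Simple, written as comments because it assumes a CUDA device and the optional Triton/causal-conv1d kernels imported at the top of this file; the dimensions are illustrative, not taken from the source.

# Usage sketch (assumes CUDA; the fused path relies on the Triton kernels imported above).
# With d_model=256, expand=2, headdim=128: d_inner = 512 and nheads = 4.
# model = Mamba2Simple(d_model=256, d_state=64, headdim=128, device="cuda", dtype=torch.float16)
# u = torch.randn(2, 1024, 256, device="cuda", dtype=torch.float16)   # (B, L, D)
# y = model(u)                                                        # same shape as u
# assert y.shape == u.shape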
mamba/mamba_ssm/modules/mamba_simple.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
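The prefill/decode split above can be exercised directly: forward handles a full prompt, while step consumes one token at a time and updates the cached states in place. A hedged sketch with illustrative sizes follows (the packaged generation loop elsewhere in mamba_ssm normally drives this):

# Decoding sketch for Mamba (hypothetical driver code; sizes are illustrative, assumes CUDA).
# model = Mamba(d_model=256, d_state=16, d_conv=4, expand=2, layer_idx=0, device="cuda")
# conv_state, ssm_state = model.allocate_inference_cache(batch_size=2, max_seqlen=128)
# token = torch.randn(2, 1, 256, device="cuda")                        # one token per step: (B, 1, D)
# out, conv_state, ssm_state = model.step(token, conv_state, ssm_state)
# assert out.shape == (2, 1, 256)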
mamba/mamba_ssm/modules/mha.py
0 → 100644
View file @
2eefe3d6
This diff is collapsed.
mamba/mamba_ssm/modules/mlp.py
0 → 100644
View file @
2eefe3d6
# Copyright (c) 2024, Tri Dao, Albert Gu.

from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y
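The hidden width above is the usual 8/3 expansion rounded up to a multiple of multiple_of: for example, in_features=1024 gives int(8 * 1024 / 3) = 2730, which rounds up to 2816, so fc1 projects to 5632 (value and gate halves) and fc2 maps 2816 back to 1024. A short sketch with these illustrative numbers (assumes import torch):

# Sizing sketch for GatedMLP (illustrative numbers, not from the source; assumes `import torch`).
# mlp = GatedMLP(1024)              # hidden_features -> 2816 after rounding to multiple_of=128
# x = torch.randn(2, 16, 1024)
# y = mlp(x)                        # fc1: 1024 -> 5632, gated elementwise, fc2: 2816 -> 1024
# assert y.shape == x.shape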