Commit ac192496 authored by hepj

Update 图片5.jpg, 图片4.jpg, 图片1.png, 图片2.jpg, 图片3.jpg, AUTHORS, inf.sh, LICENSE, model.properties, README.md.old, README.md, run.sh, requirements.txt, run-eval.sh, train-demo.py, setup.py files
Tri Dao, tri@tridao.me
Albert Gu, agu@andrew.cmu.edu
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023 Tri Dao, Albert Gu
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# MAMBA
## Paper
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
https://arxiv.org/abs/2312.00752
## Model Architecture
![图片1](图片1.png)
## Algorithm
![图片3](图片3.jpg)
![图片2](图片2.jpg)
In the original SSM, the output y is computed from fixed A, B, C, H matrices, so every input X is processed with the same dynamics. Mamba instead derives input-dependent A', B', C', H' matrices from x, and this selective SSM forms the core of the Mamba block.
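A minimal sketch of this selection step (the projection names and shapes below are illustrative, not the library's actual module):
```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the selection mechanism: Δ, B, C are produced from the input x,
# so every position gets its own discretized dynamics.
d_model, d_state = 16, 4
x = torch.randn(2, 64, d_model)            # (batch, length, d_model)

to_delta = nn.Linear(d_model, d_model)     # Δ(x): per-channel step size
to_B = nn.Linear(d_model, d_state)         # B(x): input matrix
to_C = nn.Linear(d_model, d_state)         # C(x): output matrix

delta = F.softplus(to_delta(x))            # keep the step size positive
B, C = to_B(x), to_C(x)                    # per-position B_t and C_t
# A remains a learned (d_model, d_state) parameter; discretization uses exp(Δ_t * A).
```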
Dependencies:
pip install wheel transformers==4.39 modelscope oss2==2.18.6 datasets==2.18.0 lm-eval==0.4.3
```
In /usr/local/lib/python3.10/site-packages/lm_eval/api/metrics.py, line 168, change
    exact_match = hf_evaluate.load("exact_match")
to load the local copy of the metric script (hepj/evaluate-0.4.2/metrics/exact_match/exact_match.py) instead:
    exact_match = hf_evaluate.load("hepj/evaluate-0.4.2/metrics/exact_match/exact_match.py")
After this change the dataset download appears to work.
Cache location: /root/.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7
Downloaded cache paths:
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag.lock
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/__pycache__/hellaswag.cpython-310.pyc
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag.py
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag.json
./.cache/huggingface/datasets/hellaswag
./.cache/huggingface/datasets/hellaswag/default/0.1.0/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag-test.arrow
./.cache/huggingface/datasets/hellaswag/default/0.1.0/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag-validation.arrow
./.cache/huggingface/datasets/hellaswag/default/0.1.0/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag-train.arrow
```
## Environment Setup
### Docker (Option 1)
Running in Docker is recommended. A prebuilt image can be pulled from the [光源](https://www.sourcefind.cn/#/main-page) image registry:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker run -dit --network=host --name=mamba_pytorch --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker exec -it mamba_pytorch /bin/bash
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Dockerfile (Option 2)
```
docker build -t mamba:latest .
docker run -dit --network=host --name=mamba_pytorch --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 mamba:latest
docker exec -it mamba_pytorch /bin/bash
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Conda (Option 3)
```
conda create -n mamba_pytorch python=3.10
conda activate mamba_pytorch
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Notes
The `@ http://...` packages in requirements.txt mainly depend on the bundled torch build; matching versions can be downloaded from the [光合开发者社区](https://cancon.hpccube.com:65024/4/main/).
After the environment is set up as above, the following dependencies still need to be compiled and installed:
```
cd causal-conv1d-1.4.0
python setup.py install
cd ..
python setup.py install
```
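To confirm the build, a quick import check can be run (a minimal sketch; it only assumes that both packages expose a `__version__` attribute):
```
import mamba_ssm
import causal_conv1d

# Both packages should import cleanly and report their versions after the builds above.
print(mamba_ssm.__version__, causal_conv1d.__version__)
```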
## Dataset
Downloading directly from Hugging Face through lm_eval may fail here. A workaround is to first fetch the dataset through ModelScope (which mirrors the hub) and then let the original code pick it up from the cache:
```
from modelscope.msdatasets import MsDataset
ds = MsDataset.load('modelscope/hellaswag', subset_name='default')
```
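Once the download has completed, the local Hugging Face cache can be checked without hitting the network (a sketch under assumptions: `HF_DATASETS_OFFLINE` forces offline mode, and `ctx` is a field of the public hellaswag schema):
```
import os
os.environ["HF_DATASETS_OFFLINE"] = "1"   # use only the local cache populated above

from datasets import load_dataset

ds = load_dataset("hellaswag", split="validation")
print(len(ds), ds[0]["ctx"][:80])
```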
```
# Dataset directory layout
hellaswag/
└── default
└── 0.1.0
├── 362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7
│ ├── cache-081f361bf081c0bf.arrow
│ ├── cache-5d43362a7601c065.arrow
│ ├── dataset_info.json
│ ├── hellaswag-test.arrow
│ ├── hellaswag-train.arrow
│ └── hellaswag-validation.arrow
├── 362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7_builder.lock
└── 362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7.incomplete_info.lock
```
## Model Download
[Official Mamba and Mamba-2 models](https://huggingface.co/state-spaces)
Mamba is now supported by the Transformers library, so it is enough to download the checkpoints with the `-hf` suffix. If you use a non-HF checkpoint, you also need to download the tokenizer of the EleutherAI/gpt-neox-20b model.
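As a quick sanity check, an `-hf` checkpoint can be loaded directly with the Transformers library (a minimal sketch; `MambaForCausalLM` is available in transformers 4.39 as pinned in requirements.txt, and `state-spaces/mamba-130m-hf` is one of the checkpoints above):
```
import torch
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained(
    "state-spaces/mamba-130m-hf", torch_dtype=torch.float16
).to("cuda")

inputs = tokenizer("Mamba is a state space model that", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```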
## Training
```
# The upstream repo only provides this training example; there is no more detailed recipe
./run.sh
# Evaluation: hellaswag is used as the example task; other options are lambada_openai, hellaswag, piqa, arc_easy, arc_challenge, winogrande, openbookqa
./run-eval.sh
```
## Inference
```
# Pass --prompt to ask a specific question, or --batch to run batched inference
./inf.sh    # --model-name is a Hugging Face model id or a local model path
```
## Results
![图片4](图片4.jpg)
## Accuracy
| Eval setting | acc | acc_norm |
| :-----------------------------------: | :-------------: | :-------------: |
| hellaswag, batch size 64, float16 | 0.3080 ± 0.0046 | 0.3525 ± 0.0048 |
![图片5](图片5.jpg)
## Application Scenarios
### Algorithm Category
`Dialogue / Question Answering`
### Key Application Industries
`Research, Education, Government, Finance`
## Source Repository and Issue Reporting
https://developer.hpccube.com/codes/modelzoo/mamba_pytorch
## References
https://github.com/state-spaces/mamba/tree/main
# Mamba
![Mamba](assets/selection.png "Selective State Space")
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
> Albert Gu*, Tri Dao*\
> Paper: https://arxiv.org/abs/2312.00752
![Mamba-2](assets/ssd_algorithm.png "State Space Dual Model")
> **Transformers are SSMs: Generalized Models and Efficient Algorithms**\
> **Through Structured State Space Duality**\
> Tri Dao*, Albert Gu*\
> Paper: https://arxiv.org/abs/2405.21060
## About
Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
## Installation
- [Option] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
- `pip install mamba-ssm`: the core Mamba package.
It can also be built from source with `pip install .` from this repository.
If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
Other requirements:
- Linux
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+
For AMD cards, see additional prerequisites below.
## Usage
We expose several levels of interface with the Mamba model.
### Selective SSM
Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
### Mamba Block
The main module of this repository is the Mamba architecture block wrapping the selective SSM.
Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
Usage:
``` python
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
### Mamba-2
The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).
A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py).
The usage is similar to Mamba(-1):
``` python
from mamba_ssm import Mamba2
model = Mamba2(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=64, # SSM state expansion factor, typically 64 or 128
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
#### SSD
A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions
is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).
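For intuition, the "discrete" SSM these modules compute can be written as a plain sequential recurrence (a reference sketch only, not the fused kernel; it ignores the D skip connection and follows the shapes used in the paper):
``` python
import torch

def naive_selective_scan(x, dt, A, B, C):
    """Reference recurrence: h_t = exp(dt_t*A) * h_{t-1} + dt_t*B_t*x_t, y_t = C_t · h_t."""
    batch, length, d_inner = x.shape
    d_state = A.shape[1]
    h = torch.zeros(batch, d_inner, d_state, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(length):
        dA = torch.exp(dt[:, t, :, None] * A)                           # discretized state matrix
        dBx = dt[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]   # discretized input term
        h = dA * h + dBx                                                 # state update
        ys.append((h * C[:, t, None, :]).sum(-1))                        # project state to output
    return torch.stack(ys, dim=1)                                        # (batch, length, d_inner)
```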
### Mamba Language Model
Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
This is an example of how to integrate Mamba into an end-to-end neural network.
This example is used in the generation scripts below.
## Pretrained Models
Pretrained models are uploaded to
[Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
(trained on 600B tokens on the SlimPajama dataset).
The models will be autodownloaded by the generation script below.
These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
| Parameters | Layers | Model dim. |
|------------|--------|------------|
| 130M | 24 | 768 |
| 370M | 48 | 1024 |
| 790M | 48 | 1536 |
| 1.4B | 48 | 2048 |
| 2.8B | 64 | 2560 |
(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
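A rough back-of-the-envelope check of this table, using the "roughly 3 * expand * d_model^2 parameters" estimate from the usage example above (the ~50k vocabulary size of the GPT-NeoX tokenizer is an assumption):
``` python
# Rough parameter count for the 130M model: 24 layers, d_model=768, expand=2.
d_model, n_layer, vocab = 768, 24, 50280    # vocabulary size is an assumption
per_block = 3 * 2 * d_model ** 2            # ~3.5M per Mamba block
backbone = n_layer * per_block              # ~85M
embedding = vocab * d_model                 # ~39M (shared with the LM head)
print(f"~{(backbone + embedding) / 1e6:.0f}M parameters")  # ~124M, close to the nominal 130M
```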
Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
## Evaluations
To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
we use the
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
library.
1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`.
2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
```
To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
```
To run evaluations on Mamba-2 models, simply replace the model names:
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
```
Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
## Inference
The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
1. autoloads a model from the Hugging Face Hub,
2. generates completions of a user-specified prompt,
3. benchmarks the inference speed of this generation.
Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
### Examples
To test generation latency (e.g. batch size = 1) with different sampling strategies:
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
```
To test generation throughput with random prompts (e.g. large batch size):
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
```
With Mamba-2, you just need to change the model name:
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
```
## Troubleshooting
### Precision
Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).
We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
as a first step please try a framework storing parameters in fp32 (such as AMP).
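For reference, a typical AMP step that keeps the master weights in fp32 looks roughly like this (a sketch only; the `.loss` attribute on the model output is an assumption about the surrounding training code):
``` python
import torch

scaler = torch.cuda.amp.GradScaler()  # parameters stay fp32; compute is cast per-op

def train_step(model, batch, optimizer):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch).loss      # forward pass runs in half precision where safe
    scaler.scale(loss).backward()     # gradients are scaled to avoid fp16 underflow
    scaler.step(optimizer)            # optimizer update happens on the fp32 master weights
    scaler.update()
    return loss.detach()
```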
### Initialization
Some parts of the model have initializations inherited from prior work on S4 models.
For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection.
However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
that is specific to the training framework.
## Additional Prerequisites for AMD cards
### Patching ROCm
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
```bash
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
```
## Citation
If you use this codebase, or otherwise find our work valuable, please cite Mamba:
```
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}
```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "/public/mamba-2.2.2/AI-ModelScope/mamba-130m-hf" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "/public/mamba-2.2.2/AI-ModelScope/mamba-130m-hf" --batch 64
# Unique model identifier
modelCode=823
# Model name
modelName=mamba_pytorch
# Model description
modelDescription=Mamba model based on the PyTorch framework
# Application scenarios
appScenario=training,inference,text generation,internet
# Framework type
frameType=Pytorch
absl-py==2.1.0
accelerate==0.31.0
addict==2.4.0
aiohttp==3.9.5
aiosignal==1.3.1
aitemplate @ http://10.6.10.68:8000/release/aitemplate/dtk24.04.1/aitemplate-0.0.1%2Bdas1.1.git5d8aa20.dtk2404.torch2.1.0-py3-none-any.whl#sha256=ad763a7cfd3935857cf10a07a2a97899fd64dda481add2f48de8b8930bd341dd
aliyun-python-sdk-core==2.15.1
aliyun-python-sdk-kms==2.16.3
annotated-types==0.7.0
anyio==4.4.0
apex @ http://10.6.10.68:8000/release/apex/dtk24.04.1/apex-1.1.0%2Bdas1.1.gitf477a3a.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=85eb662d13d6e6c3b61c2d878378c2338c4479bc03a1912c3eabddc2d9d08aa1
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes @ http://10.6.10.68:8000/release/bitsandbyte/dtk24.04.1/bitsandbytes-0.42.0%2Bdas1.1.gitce85679.abi1.dtk2404.torch2.1.0-py3-none-any.whl#sha256=6324e330c8d12b858d39f4986c0ed0836fcb05f539cee92a7cf558e17954ae0d
certifi==2024.6.2
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.2.1
crcmod==1.7
cryptography==43.0.0
cycler==0.12.1
DataProperty==1.0.1
datasets==2.18.0
deepspeed @ http://10.6.10.68:8000/release/deepspeed/dtk24.04.1/deepspeed-0.12.3%2Bgita724046.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=2c158ed2dab21f4f09e7fc29776cb43a1593b13cec33168ce3483f318b852fc9
dill==0.3.8
dnspython==2.6.1
dropout-layer-norm @ http://10.6.10.68:8000/release/flash_attn/dtk24.04.1/dropout_layer_norm-0.1%2Bdas1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=ae10c7cc231a8e38492292e91e76ba710d7679762604c0a7f10964b2385cdbd7
einops==0.8.0
email_validator==2.1.1
evaluate==0.4.2
exceptiongroup==1.2.1
fastapi==0.111.0
fastapi-cli==0.0.4
fastpt @ http://10.6.10.68:8000/release/fastpt/dtk24.04.1/fastpt-1.0.0%2Bdas1.1.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=ecf30dadcd2482adb1107991edde19b6559b8237379dbb0a3e6eb7306aad3f9a
filelock==3.15.1
fire==0.6.0
flash-attn @ http://10.6.10.68:8000/release/flash_attn/dtk24.04.1/flash_attn-2.0.4%2Bdas1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=7ca8e78ee0624b1ff0e91e9fc265e61b9510f02123a010ac71a2f8e5d08a62f7
flatbuffers==24.3.25
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.2.0
fused-dense-lib @ http://10.6.10.68:8000/release/flash_attn/dtk24.04.1/fused_dense_lib-0.1%2Bdas1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=7202dd258a86bb7a1572e3b44b90dae667b0c948bf0f420b05924a107aaaba03
h11==0.14.0
hjson==3.1.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.4
humanfriendly==10.0
hypothesis==5.35.1
idna==3.7
importlib_metadata==7.1.0
Jinja2==3.1.4
jmespath==0.10.0
joblib==1.4.2
jsonlines==4.0.0
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
layer-check-pt @ http://10.6.10.68:8000/release/layercheck/dtk24.04.1/layer_check_pt-1.2.3.git59a087a.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=807adae2d4d4b74898777f81e1b94f1af4d881afe6a7826c7c910b211accbea7
lightop @ http://10.6.10.68:8000/release/lightop/dtk24.04.1/lightop-0.4%2Bdas1.1git8e60f07.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=2f2c88fd3fe4be179f44c4849e9224cb5b2b259843fc5a2d088e468b7a14c1b1
lm_eval==0.4.3
lmdeploy @ http://10.6.10.68:8000/release/lmdeploy/dtk24.04.1/lmdeploy-0.2.6%2Bdas1.1.git6ba90df.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=92ecee2c8b982f86e5c3219ded24d2ede219f415bf2cd4297f989a03387a203c
lxml==5.2.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
mbstrdecoder==1.1.3
mdurl==0.1.2
mmcv @ http://10.6.10.68:8000/release/mmcv/dtk24.04.1/mmcv-2.0.1%2Bdas1.1.gite58da25.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=7a937ae22f81b44d9100907e11303c31bf9a670cb4c92e361675674a41a8a07f
mmengine==0.10.4
mmengine-lite==0.10.4
modelscope==1.16.1
more-itertools==10.3.0
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multiprocess==0.70.16
networkx==3.3
ninja==1.11.1.1
nltk==3.8.1
numexpr==2.10.1
numpy==1.24.3
onnxruntime @ http://10.6.10.68:8000/release/onnxruntime/dtk24.04.1/onnxruntime-1.15.0%2Bdas1.1.git739f24d.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=d0d24167188d2c85f1ed4110fc43e62ea40c74280716d9b5fe9540256f17869a
opencv-python==4.10.0.82
orjson==3.10.5
oss2==2.18.6
packaging==24.1
pandas==2.2.2
pathvalidate==3.2.0
peft==0.9.0
pillow==10.3.0
platformdirs==4.2.2
portalocker==2.10.1
prometheus_client==0.20.0
protobuf==5.27.1
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pybind11==2.13.1
pycparser==2.22
pycryptodome==3.20.0
pydantic==2.7.4
pydantic_core==2.18.4
Pygments==2.18.0
pynvml==11.5.0
pyparsing==3.1.2
pytablewriter==1.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
ray==2.9.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rich==13.7.1
rotary-emb @ http://10.6.10.68:8000/release/flash_attn/dtk24.04.1/rotary_emb-0.1%2Bdas1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=cc15ec6ae73875515243d7f5c96ab214455a33a4a99eb7f1327f773cae1e6721
rouge-score==0.1.2
rpds-py==0.18.1
sacrebleu==2.4.2
safetensors==0.4.3
scikit-learn==1.5.1
scipy==1.13.1
sentencepiece==0.2.0
shellingham==1.5.4
shortuuid==1.0.13
six==1.16.0
sniffio==1.3.1
sortedcontainers==2.4.0
sqlitedict==2.1.0
starlette==0.37.2
sympy==1.12.1
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.6
termcolor==2.4.0
threadpoolctl==3.5.0
tiktoken==0.7.0
tokenizers==0.15.0
tomli==2.0.1
torch @ http://10.6.10.68:8000/release/pytorch/dtk24.04.1/torch-2.1.0%2Bdas1.1.git3ac1bdd.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=5fd3bcef3aa197c0922727913aca53db9ce3f2fd4a9b22bba1973c3d526377f9
torchaudio @ http://10.6.10.68:8000/release/torchaudio/dtk24.04.1/torchaudio-2.1.2%2Bdas1.1.git63d9a68.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=4fcc556a7a2fffe64ddd57f22e5972b1b2b723f6fdfdaa305bd01551036df38b
torchvision @ http://10.6.10.68:8000/release/vision/dtk24.04.1/torchvision-0.16.0%2Bdas1.1.git7d45932.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=e3032e1bcc0857b54391d66744f97e5cff0dc7e7bb508196356ee927fb81ec01
tqdm==4.66.4
tqdm-multiprocess==0.0.11
transformers==4.39.0
triton @ http://10.6.10.68:8000/release/triton/dtk24.04.1/triton-2.1.0%2Bdas1.1.git4bf1007a.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=4c30d45dab071e65d1704a5cd189b14c4ac20bd59a7061032dfd631b1fc37645
typepy==1.3.2
typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
ujson==5.10.0
urllib3==2.2.1
uvicorn==0.30.1
uvloop==0.19.0
vllm @ http://10.6.10.68:8000/release/vllm/dtk24.04.1/vllm-0.3.3%2Bdas1.1.gitdf6349c.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=48d265b07efa36f028eca45a3667fa10d3cf30eb1b8f019b62e3b255fb9e49c4
watchfiles==0.22.0
websockets==12.0
word2number==1.1
xentropy-cuda-lib @ http://10.6.10.68:8000/release/flash_attn/dtk24.04.1/xentropy_cuda_lib-0.1%2Bdas1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=91b058d6a5fd2734a5085d68e08d3a1f948fe9c0119c46885d19f55293e2cce4
xformers @ http://10.6.10.68:8000/release/xformers/dtk24.04.1/xformers-0.0.25%2Bdas1.1.git8ef8bc1.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=ca87fd065753c1be3b9fad552eba02d30cd3f4c673f01e81a763834eb5cbb9cc
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.4
zipp==3.19.2
zstandard==0.23.0
# lm_eval --model mamba_ssm --model_args pretrained=/public/model/AI-ModelScope/mamba-130m-hf --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 64
python evals/lm_harness_eval.py --model hf --model_args pretrained=/public/model/AI-ModelScope/mamba-130m-hf --tasks hellaswag --device cuda --batch_size 64
python3 train-demo.py
# Copyright (c) 2023, Albert Gu, Tri Dao.
import sys
import warnings
import os
import re
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform
import shutil
from setuptools import setup, find_packages
import subprocess
import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
import torch
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDAExtension,
CUDA_HOME,
HIP_HOME
)
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "mamba_ssm"
BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
def get_platform():
"""
Returns the platform name as used in wheel filenames.
"""
if sys.platform.startswith("linux"):
return "linux_x86_64"
elif sys.platform == "darwin":
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
return f"macosx_{mac_version}_x86_64"
elif sys.platform == "win32":
return "win_amd64"
else:
raise ValueError("Unsupported platform: {}".format(sys.platform))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
bare_metal_ver = parse(output[release_idx].split(",")[0])
return raw_output, bare_metal_ver
def get_hip_version(rocm_dir):
hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
try:
raw_output = subprocess.check_output(
[hipcc_bin, "--version"], universal_newlines=True
)
except Exception as e:
print(
f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
)
return None, None
for line in raw_output.split("\n"):
if "HIP version" in line:
rocm_version = parse(line.split()[-1].rstrip('-').replace('-', '+')) # local version is not parsed correctly
return line, rocm_version
return None, None
def get_torch_hip_version():
if torch.version.hip:
return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
else:
return None
def check_if_hip_home_none(global_option: str) -> None:
if HIP_HOME is not None:
return
# warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
# in that case.
warnings.warn(
f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
)
def check_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
# in that case.
warnings.warn(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
return nvcc_extra_args + ["--threads", "4"]
cmdclass = {}
ext_modules = []
HIP_BUILD = bool(torch.version.hip)
if not SKIP_CUDA_BUILD:
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
cc_flag = []
if HIP_BUILD:
check_if_hip_home_none(PACKAGE_NAME)
rocm_home = os.getenv("ROCM_PATH")
_, hip_version = get_hip_version(rocm_home)
if HIP_HOME is not None:
if hip_version < Version("6.0"):
raise RuntimeError(
f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
"Note: make sure HIP has a supported version by running hipcc --version."
)
if hip_version == Version("6.0"):
warnings.warn(
f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
"Refer to the README.md for detailed instructions.",
UserWarning
)
cc_flag.append("-DBUILD_PYTHON_PACKAGE")
else:
check_if_cuda_home_none(PACKAGE_NAME)
# Check, if CUDA11 is installed for compute capability 8.0
if CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.6"):
raise RuntimeError(
f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
"Note: make sure nvcc has a supported version by running nvcc -V."
)
cc_flag.append("-gencode")
cc_flag.append("arch=compute_53,code=sm_53")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_62,code=sm_62")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_72,code=sm_72")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_87,code=sm_87")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
torch._C._GLIBCXX_USE_CXX11_ABI = True
if HIP_BUILD:
extra_compile_args = {
"cxx": ["-O3", "-std=c++17"],
"nvcc": [
"-O3",
"-std=c++17",
f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-fgpu-flush-denormals-to-zero",
]
+ cc_flag,
}
else:
extra_compile_args = {
"cxx": ["-O3", "-std=c++17"],
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
"-lineinfo",
]
+ cc_flag
),
}
ext_modules.append(
CUDAExtension(
name="selective_scan_cuda",
sources=[
"csrc/selective_scan/selective_scan.cpp",
"csrc/selective_scan/selective_scan_fwd_fp32.cu",
"csrc/selective_scan/selective_scan_fwd_fp16.cu",
"csrc/selective_scan/selective_scan_fwd_bf16.cu",
"csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
"csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
"csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
"csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
"csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
"csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
],
extra_compile_args=extra_compile_args,
include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
)
)
def get_package_version():
with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
public_version = ast.literal_eval(version_match.group(1))
local_version = os.environ.get("MAMBA_LOCAL_VERSION")
if local_version:
return f"{public_version}+{local_version}"
else:
return str(public_version)
def get_wheel_url():
# Determine the version numbers that will be used to determine the correct wheel
torch_version_raw = parse(torch.__version__)
if HIP_BUILD:
# We're using the HIP version used to build torch, not the one currently installed
torch_hip_version = get_torch_hip_version()
hip_ver = f"{torch_hip_version.major}{torch_hip_version.minor}"
else:
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
# to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
gpu_compute_version = hip_ver if HIP_BUILD else cuda_version
cuda_or_hip = "hip" if HIP_BUILD else "cu"
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
mamba_ssm_version = get_package_version()
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
wheel_url = BASE_WHEEL_URL.format(
tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename
)
return wheel_url, wheel_filename
class CachedWheelsCommand(_bdist_wheel):
"""
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
find an existing wheel (which is currently the case for all installs). We use
the environment parameters to detect whether there is already a pre-built version of a compatible
wheel available and short-circuits the standard full build pipeline.
"""
def run(self):
if FORCE_BUILD:
return super().run()
wheel_url, wheel_filename = get_wheel_url()
print("Guessing wheel URL: ", wheel_url)
try:
urllib.request.urlretrieve(wheel_url, wheel_filename)
# Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
shutil.move(wheel_filename, wheel_path)
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
super().run()
setup(
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages(
exclude=(
"build",
"csrc",
"include",
"tests",
"dist",
"docs",
"benchmarks",
"mamba_ssm.egg-info",
)
),
author="Tri Dao, Albert Gu",
author_email="tri@tridao.me, agu@cs.cmu.edu",
description="Mamba state-space model",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/state-spaces/mamba",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
ext_modules=ext_modules,
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
if ext_modules
else {
"bdist_wheel": CachedWheelsCommand,
},
python_requires=">=3.8",
install_requires=[
"torch",
"packaging",
"ninja",
"einops",
"triton",
"transformers",
# "causal_conv1d>=1.4.0",
],
)
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape