Commit 25991f98 authored by hepj

Update README

parent ac192496
@@ -20,32 +20,6 @@ https://arxiv.org/abs/2312.00752
The original SSM produces y through the A, B, C, H matrices, but because every input x is processed with the same fixed A, B, C, H, Mamba instead derives input-dependent A', B', C', H' matrices from x and uses those in the Mamba block.
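To make this concrete, here is a minimal sketch of the idea (illustrative names and shapes only, not the actual `mamba_ssm` implementation): B, C, and the step size Δ are computed from the input, which in turn makes the discretized state matrix input-dependent.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSMSketch(nn.Module):
    """Toy diagonal SSM with input-dependent (selective) parameters."""

    def __init__(self, dim, state_size):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(dim, state_size))  # fixed, learned
        self.to_B = nn.Linear(dim, state_size)  # B'(x)
        self.to_C = nn.Linear(dim, state_size)  # C'(x)
        self.to_dt = nn.Linear(dim, dim)        # Δ(x), per-channel step size

    def forward(self, x):                        # x: (batch, seqlen, dim)
        B, C = self.to_B(x), self.to_C(x)        # (b, l, n) each
        dt = F.softplus(self.to_dt(x))           # (b, l, d), positive
        b, l, d = x.shape
        h = x.new_zeros(b, d, self.A.shape[1])   # hidden state (b, d, n)
        ys = []
        for t in range(l):                       # sequential scan, for clarity
            dA = torch.exp(dt[:, t, :, None] * self.A)               # discretize A
            dBx = dt[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]
            h = dA * h + dBx                                         # h_t
            ys.append((h * C[:, t, None, :]).sum(-1))                # y_t = C' h_t
        return torch.stack(ys, dim=1)            # (batch, seqlen, dim)
```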
```
Dependencies
pip install wheel transformers==4.39 modelscope oss2==2.18.6 datasets==2.18.0 lm-eval==0.4.3
```
In /usr/local/lib/python3.10/site-packages/lm_eval/api/metrics.py, line 168, change
exact_match = hf_evaluate.load("exact_match")
to load the local copy under hepj/evaluate-0.4.2 instead:
exact_match = hf_evaluate.load("hepj/evaluate-0.4.2/metrics/exact_match/exact_match.py")
After this change the datasets appear to download correctly.
The cache directory is /root/.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7
Downloaded cache paths:
```
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag.lock
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/__pycache__/hellaswag.cpython-310.pyc
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag.py
./.cache/huggingface/modules/datasets_modules/datasets/hellaswag/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag.json
./.cache/huggingface/datasets/hellaswag
./.cache/huggingface/datasets/hellaswag/default/0.1.0/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag-test.arrow
./.cache/huggingface/datasets/hellaswag/default/0.1.0/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag-validation.arrow
./.cache/huggingface/datasets/hellaswag/default/0.1.0/362ac471216900f3f7c021863caac4eb7886347d0f76d90b6b4361f59ffea4d7/hellaswag-train.arrow
```
## Environment Setup
### Docker (Option 1)
# Copyright (c) 2023, Tri Dao, Albert Gu.
import argparse
import time
import json
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
parser = argparse.ArgumentParser(description="Generation benchmarking")
parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--promptlen", type=int, default=100)
parser.add_argument("--genlen", type=int, default=100)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--topk", type=int, default=1)
parser.add_argument("--topp", type=float, default=1.0)
parser.add_argument("--minp", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--batch", type=int, default=1)
args = parser.parse_args()
repeats = 3
device = "cuda"
dtype = torch.float16
print(f"Loading model {args.model_name}")
is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp")
if is_mamba:
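# Mamba checkpoints don't ship their own tokenizer; they reuse EleutherAI's GPT-NeoX-20B tokenizer.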
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
model.eval()
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
torch.random.manual_seed(0)
if args.prompt is None:
input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
else:
tokens = tokenizer(args.prompt, return_tensors="pt")
input_ids = tokens.input_ids.to(device=device)
attn_mask = tokens.attention_mask.to(device=device)
max_length = input_ids.shape[1] + args.genlen
if is_mamba:
fn = lambda: model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=False,
temperature=args.temperature,
top_k=args.topk,
top_p=args.topp,
min_p=args.minp,
repetition_penalty=args.repetition_penalty,
)
else:
fn = lambda: model.generate(
input_ids=input_ids,
attention_mask=attn_mask,
max_length=max_length,
return_dict_in_generate=True,
pad_token_id=tokenizer.eos_token_id,
do_sample=True,
temperature=args.temperature,
top_k=args.topk,
top_p=args.topp,
repetition_penalty=args.repetition_penalty,
)
out = fn()
if args.prompt is not None:
print(tokenizer.batch_decode(out.sequences.tolist()))
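# The fn() call above doubles as a warmup (and, on the Mamba path, CUDA graph
# capture since cg=True); below we time `repeats` full prompt + generation passes.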
torch.cuda.synchronize()
start = time.time()
for _ in range(repeats):
fn()
torch.cuda.synchronize()
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
# This workflow will:
# - Create a new GitHub release
# - Build wheels for supported architectures
# - Deploy the wheels to the GitHub release
# - Release the static code to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240505']
cuda-version: ['11.8.0', '12.2.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# Pytorch < 2.2 does not support Python 3.12
- torch-version: '2.0.1'
python-version: '3.12'
- torch-version: '2.1.2'
python-version: '3.12'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '2.0.1'
cuda-version: '12.2.2'
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.14
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install lit before installing Pytorch, we get an error for torch 2.0.1:
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools==68.0.0
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
MAX_JOBS=2 CAUSAL_CONV1D_FORCE_BUILD="TRUE" CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Log Built Wheels
run: |
ls dist
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
CAUSAL_CONV1D_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
*__pycache__/
*.egg-info/
build/
**.so
*.hip
*_hip.*
Tri Dao, tri@tridao.me
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Causal depthwise conv1d in CUDA with a PyTorch interface
Features:
- Supports fp32, fp16, and bf16.
- Kernel sizes 2, 3, and 4.
## How to use
```
from causal_conv1d import causal_conv1d_fn
```
```
def causal_conv1d_fn(x, weight, bias=None, activation=None):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
```
Equivalent to (with `dim`, `width`, and `seqlen` as in the signature above):
```
import torch.nn.functional as F
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
```
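A quick smoke test (hypothetical shapes; a CUDA GPU is required since the op is implemented in CUDA):
```python
import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)
out = causal_conv1d_fn(x, weight, bias, activation="silu")
print(out.shape)  # torch.Size([2, 64, 128])
```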
## Additional Prerequisites for AMD cards
### Patching ROCm
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
2. Apply the patch. Run with `sudo` if you encounter permission issues.
```bash
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
```
__version__ = "1.4.0"
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
# Copyright (c) 2024, Tri Dao.
import torch
import torch.nn.functional as F
import causal_conv1d_cuda
class CausalConv1dFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
weight,
bias=None,
seq_idx=None,
initial_states=None,
return_final_states=False,
final_states_out=None,
activation=None,
):
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (
initial_states is None
), "initial_states must be None if seq_idx is not None"
assert (
not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (
initial_states.stride(2) != 1 and initial_states.stride(1) != 1
):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (
final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(
batch, width - 1, dim, device=x.device, dtype=x.dtype
).transpose(1, 2)
else:
final_states_out = None
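# The CUDA kernel only needs a silu on/off flag: "swish" is an alias for silu here.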
ctx.activation = activation in ["silu", "swish"]
out = causal_conv1d_cuda.causal_conv1d_fwd(
x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
)
ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
ctx.return_final_states = return_final_states
ctx.return_dinitial_states = (
initial_states is not None and initial_states.requires_grad
)
return out if not return_final_states else (out, final_states_out)
@staticmethod
def backward(ctx, dout, *args):
x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
dfinal_states = args[0] if ctx.return_final_states else None
if dout.stride(2) != 1 and dout.stride(1) != 1:
dout = dout.contiguous()
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
# Here we just pass in None and dx will be allocated in the C++ code.
dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
x,
weight,
bias,
dout,
seq_idx,
initial_states,
dfinal_states,
None,
ctx.return_dinitial_states,
ctx.activation,
)
return (
dx,
dweight,
dbias if bias is not None else None,
None,
dinitial_states if initial_states is not None else None,
None,
None,
None,
)
def causal_conv1d_fn(
x,
weight,
bias=None,
seq_idx=None,
initial_states=None,
return_final_states=False,
final_states_out=None,
activation=None,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
return CausalConv1dFn.apply(
x,
weight,
bias,
seq_idx,
initial_states,
return_final_states,
final_states_out,
activation,
)
def causal_conv1d_ref(
x,
weight,
bias=None,
initial_states=None,
return_final_states=False,
final_states_out=None,
activation=None,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
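# With initial_states, the (width - 1) columns of left context come from the cached states instead of zero padding.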
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in
) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return out if not return_final_states else (out, final_states_out)
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state starting at the index
@cache_seqlens % state_len.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation = activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
out = causal_conv1d_cuda.causal_conv1d_update(
x, conv_state, weight, bias, activation, cache_seqlens
)
if unsqueeze:
out = out.squeeze(-1)
return out
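# Example single decoding step (hypothetical shapes; conv_state is updated in place):
#     batch, dim, width, state_len = 2, 64, 4, 3   # state_len >= width - 1
#     conv_state = torch.zeros(batch, dim, state_len, device="cuda", dtype=torch.float16)
#     weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
#     x = torch.randn(batch, dim, device="cuda", dtype=torch.float16)
#     out = causal_conv1d_update(x, conv_state, weight, activation="silu")  # (batch, dim)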
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
import torch
from torch import Tensor
import triton
import triton.language as tl
@triton.jit
def _causal_conv1d_varlen_states(
X,
CU_SEQLENS,
STATES,
state_len,
dim,
stride_x_seqlen, stride_x_dim,
stride_states_batch, stride_states_seqlen, stride_states_dim,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr
):
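# Launch grid is (cdiv(dim, BLOCK_N), cdiv(state_len, BLOCK_M), batch):
# axis 2 picks the sequence, axes 0 and 1 tile the (state_len, dim) output slab.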
batch_idx = tl.program_id(2)
STATES += batch_idx * stride_states_batch
end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
other=0)
rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
x,
mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
"""
Forward pass only, does not support backward pass.
Parameters:
x: (total_tokens, dim)
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
If some of those elements belong to a different sequence, the value of the states will be zero.
Return:
states: (batch, dim, state_len)
"""
_, dim = x.shape
batch = cu_seqlens.shape[0] - 1
cu_seqlens = cu_seqlens.contiguous()
states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
BLOCK_M = min(triton.next_power_of_2(state_len), 16)
BLOCK_N = min(triton.next_power_of_2(dim), 256)
grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
with torch.cuda.device(x.device.index):
_causal_conv1d_varlen_states[grid](
x,
cu_seqlens,
states,
state_len,
dim,
x.stride(0), x.stride(1),
states.stride(0), states.stride(2), states.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
)
return states
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
"""
Forward pass only, does not support backward pass.
Parameters:
x: (total_tokens, dim)
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
If some of those elements belong to a different sequence, the value of the states will be zero.
Return:
states: (batch, dim, state_len)
"""
_, dim = x.shape
batch = cu_seqlens.shape[0] - 1
cu_seqlens = cu_seqlens.contiguous()
states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
for i in range(batch):
end_idx = cu_seqlens[i + 1]
start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
return states
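# Quick equivalence check between the Triton kernel and the reference above
# (hypothetical inputs; requires a CUDA device):
#     x = torch.randn(10, 64, device="cuda")
#     cu_seqlens = torch.tensor([0, 3, 10], device="cuda", dtype=torch.int32)
#     out = causal_conv1d_varlen_states(x, cu_seqlens, state_len=4)
#     ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len=4)
#     assert torch.equal(out, ref)  # both (batch=2, dim=64, state_len=4)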
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
bool silu_activation;
index_t x_batch_stride;
index_t x_c_stride;
index_t x_l_stride;
index_t weight_c_stride;
index_t weight_width_stride;
index_t out_batch_stride;
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ weight_ptr;
void *__restrict__ bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
int32_t *__restrict__ cache_seqlens;
void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states.
void * initial_states_ptr;
index_t initial_states_batch_stride;
index_t initial_states_l_stride;
index_t initial_states_c_stride;
void * final_states_ptr;
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
};
struct ConvParamsBwd: public ConvParamsBase {
index_t dx_batch_stride;
index_t dx_c_stride;
index_t dx_l_stride;
index_t dweight_c_stride;
index_t dweight_width_stride;
index_t dout_batch_stride;
index_t dout_c_stride;
index_t dout_l_stride;
// Common data pointers.
void *__restrict__ dx_ptr;
void *__restrict__ dweight_ptr;
void *__restrict__ dbias_ptr;
void *__restrict__ dout_ptr;
void * dinitial_states_ptr;
index_t dinitial_states_batch_stride;
index_t dinitial_states_l_stride;
index_t dinitial_states_c_stride;
void * dfinal_states_ptr;
index_t dfinal_states_batch_stride;
index_t dfinal_states_l_stride;
index_t dfinal_states_c_stride;
};
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor_sync(uint32_t(-1), val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
#include <hip/hip_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor(val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
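// Butterfly all-reduce over THREADS lanes of a warp: each round combines values
// from lanes whose indices differ in one bit (shuffle_xor), halving the active
// width until every lane holds the reduction of all THREADS inputs.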
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, shuffle_xor(x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, shuffle_xor(x, 1));
return x;
}
};
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
using input_t = input_t_;
using weight_t = weight_t_;
static constexpr int kNThreads = kNThreads_;
static constexpr int kWidth = kWidth_;
static constexpr int kNBytes = sizeof(input_t);
static_assert(kNBytes == 2 || kNBytes == 4);
};
template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y * kNThreads + tidx;
if (channel_id >= params.dim) return;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
int state_len = params.conv_state_len;
int advance_len = params.seqlen;
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
int update_idx = cache_seqlen - (kWidth - 1);
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
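// update_idx now points at the oldest of the (kWidth - 1) previously cached
// entries when the state is used as a circular buffer.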
float weight_vals[kWidth] = {0};
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
float x_vals[kWidth] = {0};
if constexpr (!kIsCircularBuffer) {
#pragma unroll 2
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
}
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) {
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
}
x_vals[i] = float(state_val);
}
} else {
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
x_vals[i] = float(state_val);
}
}
#pragma unroll 2
for (int i = 0; i < params.seqlen; ++i) {
input_t x_val = x[i * params.x_l_stride];
if constexpr (!kIsCircularBuffer) {
if (i < advance_len && state_len - advance_len + i >= 0) {
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
}
} else {
conv_state[update_idx * params.conv_state_l_stride] = x_val;
++update_idx;
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
}
x_vals[kWidth - 1] = float(x_val);
float out_val = bias_val;
#pragma unroll
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
out[i * params.out_l_stride] = input_t(out_val);
// Shift the input buffer by 1
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
}
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
auto kernel = params.cache_seqlens == nullptr
? &causal_conv1d_update_kernel<Ktraits, false>
: &causal_conv1d_update_kernel<Ktraits, true>;
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
if (params.width == 2) {
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
} else if (params.width == 3) {
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
} else if (params.width == 4) {
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
}
}
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
--- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000
+++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000
@@ -137,7 +137,7 @@
* \ingroup HIP_INTRINSIC_BFLOAT16_CONV
* \brief Converts float to bfloat16
*/
-__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
+__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) {
__hip_bfloat16 ret;
union {
float fp32;
@@ -181,7 +181,7 @@
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
* \brief Converts and moves bfloat162 to float2
*/
-__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
+__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) {
return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
}
@@ -209,7 +209,7 @@
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
* \brief Convert double to __hip_bfloat16
*/
-__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {
+__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) {
return __float2bfloat16((float)a);
}
@@ -217,7 +217,7 @@
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
* \brief Convert float2 to __hip_bfloat162
*/
-__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
+__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
}
@@ -247,7 +247,7 @@
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
* \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
*/
-__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
+__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
@@ -275,7 +275,7 @@
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
* \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
*/
-__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
+__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV