OpenDAS / text-generation-inference / Commits

Commit 922732b2 (unverified), authored Jul 29, 2024 by Daniël de Kok, committed by GitHub on Jul 29, 2024
Install Marlin from standalone package (#2320)
Parent: 583d37a2

Showing 20 changed files with 79 additions and 7107 deletions (+79, −7107)
Changed files:
    Dockerfile                                                     +1   −11
    server/marlin/COPYRIGHT                                        +0   −20
    server/marlin/marlin_kernels/__init__.pyi                      +0   −76
    server/marlin/marlin_kernels/awq_marlin_repack.cu              +0   −269
    server/marlin/marlin_kernels/ext.cpp                           +0   −16
    server/marlin/marlin_kernels/ext.hh                            +0   −39
    server/marlin/marlin_kernels/fp8_marlin.cu                     +0   −1305
    server/marlin/marlin_kernels/gptq_marlin.cu                    +0   −2195
    server/marlin/marlin_kernels/gptq_marlin_repack.cu             +0   −344
    server/marlin/marlin_kernels/marlin.cuh                        +0   −87
    server/marlin/marlin_kernels/marlin_cuda_kernel.cu             +0   −1138
    server/marlin/marlin_kernels/marlin_dtypes.cuh                 +0   −79
    server/marlin/marlin_kernels/py.typed                          +0   −0
    server/marlin/marlin_kernels/sparse/common/base.h              +0   −51
    server/marlin/marlin_kernels/sparse/common/mem.h               +0   −136
    server/marlin/marlin_kernels/sparse/common/mma.h               +0   −191
    server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu   +0   −1125
    server/marlin/setup.py                                         +0   −24
    server/poetry.lock                                             +70  −1
    server/pyproject.toml                                          +8   −0
Dockerfile

@@ -140,13 +140,6 @@ COPY server/Makefile-eetq Makefile
 # Build specific version of transformers
 RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
 
-# Build marlin kernels
-FROM kernel-builder AS marlin-kernels-builder
-WORKDIR /usr/src
-COPY server/marlin/ .
-# Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
-
 # Build Lorax Punica kernels
 FROM kernel-builder AS lorax-punica-builder
 WORKDIR /usr/src

@@ -231,9 +224,6 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# Copy build artifacts from marlin kernels builder
-COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from fbgemm builder
 COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from vllm builder

@@ -252,7 +242,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
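Note on the Dockerfile change above: the dedicated marlin-kernels-builder stage is removed, and the final image is expected to get the kernels from the prebuilt marlin-kernels wheel pulled in through the new `marlin` extra. A minimal sanity check one could run inside the built image (a sketch, not part of the repository; it assumes the wheel installs a module named `marlin_kernels`, matching the package name used in setup.py below):

import importlib.util

# If the `marlin` extra was resolved during `pip install`, the prebuilt wheel
# should expose a `marlin_kernels` module and this spec lookup succeeds.
spec = importlib.util.find_spec("marlin_kernels")
print("marlin_kernels available:", spec is not None)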
server/marlin/COPYRIGHT (deleted, 100644 → 0)
These kernels were vendored from VLLM. The Marlin kernels were developed
by Elias Frantar and extended by Neural Magic.
---
Copyright (C) Marlin.2024 Elias Frantar
Modified by Neural Magic
Copyright 2024 The vLLM team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
server/marlin/marlin_kernels/__init__.pyi (deleted, 100644 → 0)

import torch

def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports converted GPTQ kernels.
    """
    ...

def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports 2:4 sparsity.
    """
    ...

def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ parameters for Marlin kernels."""
    ...

def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels.
    """
    ...

# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )
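With the type stubs above deleted along with the vendored sources, the same entry points are expected to come from the standalone marlin-kernels wheel instead. A minimal sketch (assuming that wheel is installed) that checks the former in-tree API surface is still exposed:

import marlin_kernels

# Names taken from the deleted __init__.pyi and the ext.cpp bindings below.
EXPECTED = (
    "awq_marlin_repack",
    "gptq_marlin_gemm",
    "gptq_marlin_24_gemm",
    "gptq_marlin_repack",
    "marlin_gemm",
    "fp8_marlin_gemm",
)

missing = [name for name in EXPECTED if not hasattr(marlin_kernels, name)]
assert not missing, f"missing Marlin kernel entry points: {missing}"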
server/marlin/marlin_kernels/awq_marlin_repack.cu (deleted, 100644 → 0)

#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                int64_t size_k, int64_t size_n,
                                int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };

#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                   \
  else if (num_bits == NUM_BITS) {                                          \
    cudaFuncSetAttribute(                                                   \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);       \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>      \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(       \
            b_q_weight_ptr, out_ptr, size_k, size_n);                       \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
server/marlin/marlin_kernels/ext.cpp (deleted, 100644 → 0)

#include <torch/extension.h>

#include "ext.hh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("awq_marlin_repack", &awq_marlin_repack,
        "Repack AWQ parameters for Marlin");
  m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
        "Marlin gemm with GPTQ compatibility");
  m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
  m.def("gptq_marlin_repack", &gptq_marlin_repack,
        "Repack GPTQ parameters for Marlin");
  m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}
server/marlin/marlin_kernels/ext.hh (deleted, 100644 → 0)

#pragma once

#include <torch/library.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace, int64_t num_bits,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight,
                                 torch::Tensor& perm, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

#endif
server/marlin/marlin_kernels/fp8_marlin.cu (deleted, 100644 → 0)
[diff collapsed in the source view; 1,305 deleted lines not shown]
server/marlin/marlin_kernels/gptq_marlin.cu (deleted, 100644 → 0)
[diff collapsed in the source view; 2,195 deleted lines not shown]
server/marlin/marlin_kernels/gptq_marlin_repack.cu (deleted, 100644 → 0)

#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int perm_size = tile_k_size / 4;

  int4* sh_perm_ptr = sh;
  int4* sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }

  constexpr int tile_ints = tile_k_size / pack_factor;

  constexpr int stage_n_threads = tile_n_size / 4;
  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;

    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;

    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        uint32_t const* sh_perm_int_ptr =
            reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

    } else {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor;

        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                  reinterpret_cast<int4 const*>(
                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                       first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;

    constexpr int sh_stride = 64;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

    uint32_t vals[8];

    if constexpr (has_perm) {
      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;

        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;

        vals[i] = b1_cur_val;
        vals[4 + i] = b2_cur_val;
      }

    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

#pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
      }

#pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        int cur_int = cur_elem / pack_factor;
        int cur_pos = cur_elem % pack_factor;

        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
      }
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };

#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS, HAS_PERM)                                          \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \
    cudaFuncSetAttribute(                                                    \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,  \
                                          HAS_PERM>,                         \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,      \
                                      HAS_PERM>                              \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \
  }

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", pack_factor = ", pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(4, true)
  CALL_IF(8, false)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
                ", has_perm = ", has_perm);
  }

  return out;
}

#endif
server/marlin/marlin_kernels/marlin.cuh (deleted, 100644 → 0)

#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace marlin {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages = 4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Repack params
static constexpr int repack_stages = 8;

static constexpr int repack_threads = 256;

static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

// Helpers
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace marlin
server/marlin/marlin_kernels/marlin_cuda_kernel.cu (deleted, 100644 → 0)
[diff collapsed in the source view; 1,138 deleted lines not shown]
server/marlin/marlin_kernels/marlin_dtypes.cuh (deleted, 100644 → 0)

#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace marlin

#endif
server/marlin/marlin_kernels/py.typed (deleted, 100644 → 0; empty marker file)
server/marlin/marlin_kernels/sparse/common/base.h (deleted, 100644 → 0)
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

namespace marlin_24 {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales

}  // namespace marlin_24
server/marlin/marlin_kernels/sparse/common/mem.h (deleted, 100644 → 0)
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"

namespace marlin_24 {

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                            const void* glob_ptr,
                                            bool pred = true,
                                            const bool zfill = false) {
  const int BYTES = 16;
  int src_in_bytes = (zfill ? 0 : BYTES);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
               : "=r"(a[0]), "=r"(a[1])
               : "r"(smem));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(smem));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}

}  // namespace marlin_24
server/marlin/marlin_kernels/sparse/common/mma.h (deleted, 100644 → 0)
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
#include <cudaTypedefs.h>

namespace marlin_24 {

// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
//   | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
//   | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
//   | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  #define MMA_SP_INST \
    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif

// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
                              const int psel) {
  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);

  float* c = reinterpret_cast<float*>(&frag_c);
  if (psel == 0) {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  } else {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  }
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                          float c3) {
  uint2 r;
  asm("{\n\t"
      ".reg .f16 a, b, c, d; \n\t"
      "cvt.rn.f16.f32 a, %2; \n\t"
      "cvt.rn.f16.f32 b, %3; \n\t"
      "cvt.rn.f16.f32 c, %4; \n\t"
      "cvt.rn.f16.f32 d, %5; \n\t"
      "mov.b32 %0, {a, b}; \n\t"
      "mov.b32 %1, {c, d}; \n\t"
      "}"
      : "=r"(r.x), "=r"(r.y)
      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  return r;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
                                    FragS& s0, float* c4, float* c5, float* c6,
                                    float* c7, FragS& s1) {
  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));

  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}

}  // namespace marlin_24
server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu (deleted, 100644 → 0)
[diff collapsed in the source view; 1,125 deleted lines not shown]
server/marlin/setup.py (deleted, 100644 → 0)

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

extra_compile_args = []

setup(
    name="marlin_kernels",
    ext_modules=[
        CUDAExtension(
            name="marlin_kernels",
            sources=[
                "marlin_kernels/awq_marlin_repack.cu",
                "marlin_kernels/fp8_marlin.cu",
                "marlin_kernels/gptq_marlin.cu",
                "marlin_kernels/gptq_marlin_repack.cu",
                "marlin_kernels/marlin_cuda_kernel.cu",
                "marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
                "marlin_kernels/ext.cpp",
            ],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
server/poetry.lock

@@ -1139,6 +1139,74 @@ files = [
     {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
 ]
 
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"

@@ -3507,6 +3575,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [extras]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
+marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
 outlines = ["outlines"]
 peft = ["peft"]
 quantize = ["accelerate", "datasets", "texttable"]

@@ -3515,4 +3584,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "c94bbdf8131750891fb3f7132066718534129d85a4c09126d8d01c2de6c72798"
+content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1"
server/pyproject.toml

@@ -40,10 +40,18 @@ py-cpuinfo = "^9.0.0"
 # Remove later, temporary workaround for outlines.
 numpy = "^1.26"
 
+marlin-kernels = [
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
+
 [tool.poetry.extras]
 torch = ["torch"]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
+marlin = ["marlin-kernels"]
 peft = ["peft"]
 quantize = ["texttable", "datasets", "accelerate"]
 outlines = ["outlines"]
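The new marlin-kernels dependency is declared as a Poetry multiple-constraints entry: exactly one wheel URL is selected per interpreter version, and the `marlin` extra exposes it as an optional install. The mapping these entries encode can be summarized with a small illustrative sketch (not code from the repository):

import sys

# Wheel filenames copied from the URLs added in pyproject.toml above
# (marlin-kernels 0.2.0, CUDA 12.3, torch 2.4 builds).
WHEELS = {
    (3, 9): "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl",
    (3, 10): "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl",
    (3, 11): "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl",
    (3, 12): "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl",
}

wheel = WHEELS.get(sys.version_info[:2])
print(wheel or "no prebuilt marlin-kernels wheel for this interpreter")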