"vscode:/vscode.git/clone" did not exist on "f7ee993fdc8deeb4f91b350f2362ac5aaa7e5ab9"
Unverified Commit 81e6345d authored by Matthew Douglas, committed by GitHub

LLM.int8() Refactoring: Part 1 (#1401)



* Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation

* Fix unintended change

* New naive mm_dequant kernel for row-major; cleanup

* fix

* int8 refactor: initial sparse decomp, cleanup

* Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup

* int8: inference optimizations, some cleanup

* int8: more tests passing, cleanup

* int8 - more cleanup, most tests passing

* int8: specify CUDA stream for int8 ops

* perf: reduce overhead from getting cudaStream ptr

* Mark some functions for deprecation.

* int8 sparse decomp: small perf improvement

* update setup.py

* Update bitsandbytes/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn

* int8 cleanup

* Ignore ruff rule ISC001 (incompatible with formatter)

* add comment

* int8 more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8: rename / deprecate old fn signatures

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* type annotation

* format update

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Add comment to explain division optimization

* more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Type annotations, cleanup

* remove unused kernels; improved type annotations

* small perf optimization for single-GPU systems

* small perf optimization for single-GPU systems

* update docstrings

* Improve docs and tests

* Update docstring

* Update test

* add benchmarking script

* test cleanup: add deprecated marker, move benchmarks out

* Add int8 dequant function; misc improvements

* int8 matmul fallback for inner dims not divisible by 4

* improve register usage of kInt8VectorQuant - especially for A100/H100

* disable fail-fast for package build

* maxwell compat

* ptxas verbose

* docs update

* doc update

* backward fix

* Bugfix sparse decomp

* Int8 fix for PEFT OLoRA init

* Fix test for deprecated spmm_coo

* test improvement

* doc update

* typo

* doc cleanup

* docs

* add inference benchmark script

* Add benchmarks, doc update

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
parent 7dca7004
......@@ -314,8 +314,6 @@ int roundoff(int v, int d) {
}
#ifdef NO_CUBLASLT
#else
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
......@@ -347,7 +345,6 @@ template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
#endif
template<int ORDER> int get_leading_dim(int dim1, int dim2)
......@@ -379,8 +376,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
#ifdef NO_CUBLASLT
#else
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
......@@ -419,69 +414,98 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
#endif
}
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
return ERR_NOT_IMPLEMENTED;
#else
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasOperation_t opT = CUBLAS_OP_T;
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(FORMATB == COL_TURING)
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
else
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
cublasLtHandle_t ltHandle,
int m, int n, int k,
const int8_t * A,
const int8_t * B,
void * C,
float * row_scale,
int lda, int ldb, int ldc,
cudaStream_t stream
) {
if(DTYPE_OUT == 32)
{
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
// Calculate C = A^T @ B, in col-major layout.
//
// Using the IMMA kernels requires:
// * A must be transposed and B must be non-transposed.
// * Dimensions m and k must be multiples of 4.
// * All pointers must be 4-byte aligned; 16-byte alignment preferred.
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t aDesc, bDesc, cDesc;
cublasOperation_t opT = CUBLAS_OP_T;
cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I;
cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F;
cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));
// Default layout order is col major
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
if (DTYPE_OUT == 32) {
int alpha = 1, beta = 0;
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
}
else
{
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(!SCALE_ROWS)
{
float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
}
else
{
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
}
has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc,
&alpha, A, aDesc,
B, bDesc, &beta,
(int32_t*)C, cDesc,
(int32_t*)C, cDesc,
NULL, NULL, 0, stream
));
} else {
// This path is unlikely to be used, as 8-bit accumulation is likely to overflow.
if (!SCALE_ROWS) {
float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc,
&alpha, A, aDesc,
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
));
} else {
float beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(
matmulDesc,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointerMode,
sizeof(pointerMode)
));
has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc,
row_scale, A, aDesc,
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
));
}
}
has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc));
has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
if(has_error == 1)
printf("error detected");
return has_error;
#endif // NO_CUBLASLT
}
int fill_up_to_nearest_multiple(int value, int multiple)
......@@ -489,64 +513,32 @@ int fill_up_to_nearest_multiple(int value, int multiple)
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream)
{
int threads = 512;
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
int n = numRows*tileCols;
int subtile_rows = 128;
int tilesize = 32*subtile_rows;
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
assert(threads <= tilesize);
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
const int threads = 512;
const int num_per_thread = 4;
const int num_per_block = threads * num_per_thread;
const int n = numRows*numCols;
const int num_blocks = (n + num_per_block - 1) / num_per_block;
kdequant_mm_int32_fp16<num_per_thread, threads><<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
#define STATS_THREADS 64
#define STATS_ITEMS 4
#define STATS_ROWS 16
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
int row_tiles = (tiledRows/STATS_ROWS);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
if (threshold == 0.0) {
kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
} else {
kInt8VectorQuant<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
}
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
{
int threads = 64;
int items_per_thread = 4;
int tile_cols = threads*items_per_thread;
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
if (threshold == 0.0)
kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
else
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
......@@ -596,10 +588,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
#ifdef NO_CUBLASLT
#else
cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;
......@@ -646,7 +634,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
#endif
}
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
......@@ -766,12 +753,9 @@ template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
......
......@@ -29,7 +29,6 @@
exit(1); \
} }
#define THREADS_PER_BLOCKS (512)
#define CHECK_CUSPARSE(value) { \
cusparseStatus_t _m_cudaStat = value; \
......@@ -40,9 +39,6 @@
} }
#define THREADS_PER_BLOCKS (512)
inline void checkCudaStatus(cudaError_t status) {
if (status != cudaSuccess) {
printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
......@@ -175,15 +171,13 @@ void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, i
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount);
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols);
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
......
......@@ -175,23 +175,15 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo
void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
......@@ -316,25 +308,15 @@ extern "C"
Context *get_context(){ return new Context(); }
ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
//{ (cublasLtHandle_t)context->m_handle; return 0; }
//{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
......@@ -351,13 +333,14 @@ extern "C"
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); }
void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
{ doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
getRowStats(A, rowStats, threshold, rows, cols, stream);
}
void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
}
void ctransform_row2col32(char * A, char *out, int rows, int cols)
{ transform_row2col32(A, out, rows, cols); }
......
......@@ -32,6 +32,8 @@
title: Papers, resources & how to cite
- title: API reference
sections:
- title: Functional
local: reference/functional
- title: Optimizers
sections:
- local: reference/optim/optim_overview
......@@ -57,7 +59,7 @@
- title: k-bit quantizers
sections:
- local: reference/nn/linear8bit
title: 8-bit quantizer
title: LLM.int8()
- local: reference/nn/linear4bit
title: 4-bit quantizer
- local: reference/nn/embeddings
......
......@@ -5,7 +5,7 @@ This is an overview of the `bnb.functional` API in `bitsandbytes` that we think
## Using Int8 Matrix Multiplication
For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter:
For straight Int8 matrix multiplication without mixed precision decomposition, you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter:
```py
bnb.matmul(..., threshold=6.0)
```
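As a hedged illustration (a sketch, not taken from the documentation above), the effect of the threshold can be inspected directly with the `bnb.functional` quantization primitive: columns containing values at or above the threshold are reported as outliers, and `bnb.matmul` computes those columns in fp16 instead of int8.

```py
import torch
import bitsandbytes.functional as F

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
x[:, 123] *= 20.0  # plant an outlier feature dimension

# With a nonzero threshold, the columns holding outlier values are returned
# so that they can be handled separately in 16-bit precision.
x_i8, row_stats, outlier_cols = F.int8_vectorwise_quant(x, threshold=6.0)
print(outlier_cols)  # expected to include column 123
```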
......@@ -49,7 +49,7 @@ Authors: Tim Dettmers, Luke Zettlemoyer
}
```
## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339)
## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) [[llm-int8]]
Authors: Tim Dettmers, Mike Lewis, Younes Belkada, Luke Zettlemoyer
- [LLM.int8() Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration)
......
......@@ -3,7 +3,7 @@
bitsandbytes enables accessible large language models via k-bit quantization for PyTorch. bitsandbytes provides three main features for dramatically reducing memory consumption for inference and training:
* 8-bit optimizers use block-wise quantization to maintain 32-bit performance at a small fraction of the memory cost.
* LLM.Int() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication.
* LLM.int8() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bit, while outliers are handled separately with 16-bit matrix multiplication.
* QLoRA or 4-bit quantization enables large language model training with several memory-saving techniques that don't compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training.
# License
......
......@@ -23,25 +23,28 @@ Welcome to the installation guide for the `bitsandbytes` library! This document
### Supported CUDA Configurations[[cuda-pip]]
The latest version of `bitsandbytes` builds on the following configurations:
The latest version of the distributed `bitsandbytes` package is built with the following configurations:
| **OS** | **CUDA Version** | **Compiler** |
| **OS** | **CUDA Toolkit** | **Host Compiler** |
|-------------|------------------|----------------------|
| **Linux** | 11.7 - 12.3 | GCC 11.4 |
| | 12.4 - 12.6 | GCC 13.2 |
| **Windows** | 11.7 - 12.6 | MSVC 19.42+ (VS2022) |
| | 12.4+ | GCC 13.2 |
| **Windows** | 11.7 - 12.6 | MSVC 19.38+ (VS2022) |
For Linux systems, ensure your hardware meets the following requirements:
For CUDA systems, ensure your hardware meets the following requirements:
| **Feature** | **Hardware Requirement** |
|---------------------------------|--------------------------------------------------------------------|
| LLM.int8() | NVIDIA Turing (RTX 20 series, T4) or Ampere (RTX 30 series, A4-A100) GPUs |
| 8-bit optimizers/quantization | NVIDIA Kepler (GTX 780 or newer) |
| **Feature** | **Minimum Hardware Requirement** |
|---------------------------------|---------------------------------------------------------------|
| LLM.int8() | NVIDIA Turing (RTX 20 series, T4) or newer GPUs |
| 8-bit optimizers/quantization | NVIDIA Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs * |
| NF4/FP4 quantization | NVIDIA Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs * |
> [!WARNING]
> `bitsandbytes >= 0.39.1` no longer includes Kepler binaries in pip installations. This requires [manual compilation using](#cuda-compile) the `cuda11x_nomatmul_kepler` configuration.
To install from PyPI.
> `bitsandbytes >= 0.45.0` no longer supports Kepler GPUs.
>
> Support for Maxwell GPUs is deprecated and will be removed in a future release. For the best results, a Turing generation device or newer is recommended.
```bash
pip install bitsandbytes
......@@ -79,7 +82,7 @@ For Linux and Windows systems, compiling from source allows you to customize the
<hfoptions id="source">
<hfoption id="Linux">
To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (`gcc`, `make`, headers, etc.).
To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. Make sure you have a compiler installed to compile C++ (`gcc`, `make`, headers, etc.).
For example, to install a compiler and CMake on Ubuntu:
......
# Overview
The `bitsandbytes.functional` API provides the low-level building blocks for the library's features.
## When to Use `bitsandbytes.functional`
* When you need direct control over quantized operations and their parameters.
* To build custom layers or operations leveraging low-bit arithmetic.
* To integrate with other ecosystem tooling.
* For experimental or research purposes requiring non-standard quantization or performance optimizations.
## LLM.int8()
[[autodoc]] functional.int8_double_quant
[[autodoc]] functional.int8_linear_matmul
[[autodoc]] functional.int8_mm_dequant
[[autodoc]] functional.int8_vectorwise_dequant
[[autodoc]] functional.int8_vectorwise_quant
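As a hedged sketch of how these primitives compose (not taken from this diff; the argument order and names follow the docstrings added here but may differ in detail), an fp16 linear projection can be approximated by quantizing both operands row-wise, running the int8 GEMM, and dequantizing the int32 result:

```py
import torch
import bitsandbytes.functional as F

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")    # activations
W = torch.randn(11008, 4096, dtype=torch.float16, device="cuda") # weight, shape (out, in)

# Row-wise (vector-wise) absmax quantization of both operands to int8.
x_i8, x_stats, _ = F.int8_vectorwise_quant(x)
w_i8, w_stats, _ = F.int8_vectorwise_quant(W)

# int8 GEMM with int32 accumulation, computing x @ W.T.
out_i32 = F.int8_linear_matmul(x_i8, w_i8)

# Rescale the int32 accumulator back to fp16 using the per-row statistics.
out_fp16 = F.int8_mm_dequant(out_i32, x_stats, w_stats)

# fp16 reference; a small quantization error is expected.
ref = x @ W.t()
```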
## 4-bit
[[autodoc]] functional.dequantize_4bit
[[autodoc]] functional.dequantize_fp4
[[autodoc]] functional.dequantize_nf4
[[autodoc]] functional.gemv_4bit
[[autodoc]] functional.quantize_4bit
[[autodoc]] functional.quantize_fp4
[[autodoc]] functional.quantize_nf4
[[autodoc]] functional.QuantState
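For comparison, a hedged round-trip sketch of the 4-bit path (NF4 shown; the blocksize and keyword names are assumptions based on the library defaults):

```py
import torch
import bitsandbytes.functional as F

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Pack to 4-bit NF4 (two values per byte) with per-block absmax scaling.
w4, quant_state = F.quantize_4bit(w, blocksize=64, quant_type="nf4")

# Dequantize back to fp16 for inspection or debugging.
w_hat = F.dequantize_4bit(w4, quant_state)
```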
## Dynamic 8-bit Quantization
Primitives used in the 8-bit optimizer quantization.
For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
[[autodoc]] functional.dequantize_blockwise
[[autodoc]] functional.quantize_blockwise
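A hedged example of the blockwise round trip (default settings assumed):

```py
import torch
import bitsandbytes.functional as F

x = torch.randn(4096, 4096, device="cuda")

# Quantize in blocks; returns uint8 codes plus a QuantState holding the
# per-block absmax values and the quantization code book.
q, quant_state = F.quantize_blockwise(x)

# Dequantize back to float32; the reconstruction error is small but nonzero.
x_hat = F.dequantize_blockwise(q, quant_state)
max_err = (x - x_hat).abs().max()
```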
## Utility
[[autodoc]] functional.get_ptr
[[autodoc]] functional.is_on_gpu
# 8-bit quantization
# LLM.int8()
[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that aims to make large language model inference more accessible without significant degradation. Unlike naive 8-bit quantization, which can result in loss of critical information and accuracy, LLM.int8() dynamically adapts to ensure sensitive components of the computation retain higher precision when needed. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output.
[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance, which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are quantized to Int8 and multiplied in 8-bit before being dequantized back to 16-bit. The outputs from the 16-bit and 8-bit multiplications are combined to produce the final output.
[Further Resources](../../explanations/resources#llm-int8)
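As a hedged usage sketch (not part of this diff), the `Linear8bitLt` module documented below is the usual entry point; `threshold=6.0` enables the mixed-precision decomposition described above, and `has_fp16_weights=False` keeps the weights in int8 for inference:

```py
import torch
import bitsandbytes as bnb

fp16_linear = torch.nn.Linear(4096, 4096, bias=True).half()

int8_linear = bnb.nn.Linear8bitLt(
    4096, 4096, bias=True,
    has_fp16_weights=False,  # keep quantized int8 weights for inference
    threshold=6.0,           # outlier features are computed in fp16
)
int8_linear.load_state_dict(fp16_linear.state_dict())
int8_linear = int8_linear.cuda()  # quantization happens when moving to the GPU

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = int8_linear(x)
```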
## Linear8bitLt
......
......@@ -31,6 +31,7 @@ ignore = [
"E731", # Do not use lambda
"F841", # Local assigned but not used (TODO: enable, these are likely bugs)
"RUF012", # Mutable class attribute annotations
"ISC001", # single-line-implicit-string-concatenation incompatible with formatter
]
[tool.ruff.lint.extend-per-file-ignores]
......
......@@ -11,3 +11,4 @@ log_file = logs/pytest.log
markers =
benchmark: mark test as benchmark
slow: mark test as slow
deprecated: mark test as covering a deprecated feature
......@@ -31,10 +31,10 @@ setup(
description="k-bit optimizers and matrix multiplication routines.",
license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression",
url="https://github.com/TimDettmers/bitsandbytes",
url="https://github.com/bitsandbytes-foundation/bitsandbytes",
packages=find_packages(),
package_data={"": libs},
install_requires=["torch", "numpy"],
install_requires=["torch", "numpy", "typing_extensions>=4.8.0"],
extras_require={
"benchmark": ["pandas", "matplotlib"],
"test": ["scipy", "lion_pytorch"],
......
......@@ -7,10 +7,6 @@ import torch
def pytest_runtest_call(item):
try:
item.runtest()
except NotImplementedError as nie:
if "NO_CUBLASLT" in str(nie):
pytest.skip("CUBLASLT not available")
raise
except AssertionError as ae:
if str(ae) == "Torch not compiled with CUDA enabled":
pytest.skip("Torch not compiled with CUDA enabled")
......
......@@ -28,6 +28,7 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
......@@ -198,10 +199,10 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
assert (idx == 0).sum().item() < n * 0.02
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [48], ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize(
"funcs",
......@@ -249,13 +250,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
if not has_fp16_weights:
if not transpose[0] and not transpose[1]:
B2 = B2.t().contiguous()
(
state.CB,
CBt,
state.SCB,
SCBt,
coo_tensorB,
) = bnb.functional.double_quant(B2.to(torch.float16))
state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16))
B2 = state.CB
if not transpose[0] and transpose[1]:
......@@ -313,11 +309,13 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.10
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
if req_grad[2]:
......
......@@ -13,15 +13,6 @@ def cuda120_spec() -> CUDASpecs:
)
@pytest.fixture
def cuda111_noblas_spec() -> CUDASpecs:
return CUDASpecs(
cuda_version_string="111",
highest_compute_capability=(7, 2),
cuda_version_tuple=(11, 1),
)
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
......@@ -31,14 +22,3 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "125")
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"