Unverified Commit 849d9449 authored by Matthew Douglas, committed by GitHub

Deprecation cleanup (#1669)

* Deprecation cleanup: remove histogram_scatter_add_2d

* Deprecation cleanup: vectorwise_mm_dequant

* Deprecation cleanup: vectorwise_quant

* Remove unused test

* Optimizer test cleanup

* Deprecations: remove estimate_quantiles, create_quantile_map

* Move deprecated test
parent 76d3e2b1
@@ -401,23 +401,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
return torch.tensor(data, dtype=torch.float32)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def create_quantile_map(A, total_bits=8):
q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
q = q.tolist()
q.append(0)
gap = 256 - len(q)
for i in range(gap):
q.append(0)
q.sort()
q = Tensor(q)
q = q / q.abs().max()
return q
def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
"""Verifies that the input tensors are all on the same device.
@@ -474,74 +457,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
return ct.c_void_p(A.data_ptr())
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def estimate_quantiles(
A: Tensor,
out: Optional[torch.Tensor] = None,
offset: float = 1 / 512,
num_quantiles=256,
) -> Tensor:
"""
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
and the extreme quantiles close to 0 and 1 have high variance / large estimation
errors. These large errors can be avoided by using the offset variable which trims
the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
trims 1/512 = 0.2% from each side of the distribution. An offset value of 0.01 to 0.02
usually has a much lower error but is not a minimum entropy encoding. Given an offset
of 0.02, equidistant points in the range [0.02, 0.98] are used for the quantiles.
Parameters
----------
A : torch.Tensor
The input tensor. Any shape.
out : torch.Tensor
Tensor with the 256 estimated quantiles.
offset : float
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
num_quantiles : int
The number of equally spaced quantiles.
Returns
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
"""
if A.numel() < 256:
raise NotImplementedError(
f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.",
)
if num_quantiles > 256:
raise NotImplementedError(
f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}",
)
if num_quantiles < 256 and offset == 1 / (512):
# override default arguments
offset = 1 / (2 * num_quantiles)
if out is None:
out = torch.zeros((256,), dtype=torch.float32, device=A.device)
with _cuda_device_of(A):
is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementedError(f"Not supported data type {A.dtype}")
if num_quantiles < 256:
step = round(256 / num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]
return out
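For callers that still need these codes after the removal, an exact but slower stand-in can be built on torch.quantile. This is only a sketch that mirrors the docstring above (equidistant quantile positions trimmed by offset on both tails); it is not a drop-in for the fast approximate CUDA kernel:

import torch

def estimate_quantiles_reference(A: torch.Tensor, offset: float = 1 / 512, num_quantiles: int = 256) -> torch.Tensor:
    # Evaluate num_quantiles equidistant quantiles of A's eCDF, trimming
    # offset from each tail (offset=1/512 matches the removed default).
    percs = torch.linspace(offset, 1 - offset, num_quantiles, device=A.device)
    return torch.quantile(A.float().flatten(), percs)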
class QuantState:
"""container for quantization state components to work with Params4bit and similar classes"""
@@ -1601,25 +1516,6 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
return current_gnorm, clip_value, gnorm_scale
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
assert index1.dtype == torch.int32
assert index2.dtype == torch.int32
assert histogram.device.type == "cuda"
assert index1.device.type == "cuda"
assert index2.device.type == "cuda"
assert source.device.type == "cuda"
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
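A pure-PyTorch stand-in for the removed helper, kept as a sketch: it reproduces the kernel's flattened index (index1 * histogram.shape[0] + index2, as in kHistogramScatterAdd2D below) and relies on index_add_ to accumulate duplicate indices:

import torch

def histogram_scatter_add_2d_reference(histogram: torch.Tensor, index1: torch.Tensor, index2: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    # Flatten the 2-D index the same way the removed CUDA kernel did and
    # accumulate source values, summing contributions at duplicate indices.
    flat_idx = index1.long() * histogram.shape[0] + index2.long()
    histogram.view(-1).index_add_(0, flat_idx, source)
    return histogram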
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized():
torch.cuda.init()
@@ -2426,118 +2322,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0
@deprecated(
"This function is deprecated and will be removed in a future release. "
"Consider using `int8_vectorwise_quant` instead.",
category=FutureWarning,
)
def vectorwise_quant(x, dim=1, quant_type="vector"):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type in ["vector", "row"]:
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x * (C / max1)).to(torch.int8)
return xq, max1
elif quant_type == "zeropoint":
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0:
dyna = 1
qx = 255.0 / dyna
minx = x.min()
zpx = torch.round(minx * qx)
x = torch.round(qx * x - zpx) + zpx
return x, qx
elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
dyna[dyna == 0] = 1
qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
zpx = torch.round(minx * qx)
x = torch.round(qx * x - zpx) + zpx
return x, qx
elif quant_type == "truncated-vector":
with torch.no_grad():
absx = torch.abs(x)
max1 = torch.amax(absx, dim=dim, keepdim=True)
max1 = max1 * 0.7
idx = absx > max1.expand_as(absx)
sign = torch.sign(x[idx])
x[idx] = max1.expand_as(absx)[idx] * sign
xq = torch.round(x / max1 * C).to(torch.int8)
return xq, max1
else:
return None
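The deprecation message points to int8_vectorwise_quant for the common row-wise ("vector") path; a minimal migration sketch under that assumption (the other quant_type modes have no single-call replacement shown in this diff):

import torch
import bitsandbytes.functional as F

A = torch.randn(128, 256, device="cuda", dtype=torch.float16)

# Previously: xq, max1 = F.vectorwise_quant(A, dim=1, quant_type="vector")
# int8_vectorwise_quant returns the int8 rows, the per-row absmax as float32,
# and a third outlier-related value that is unused here.
CA, row_stats, _ = F.int8_vectorwise_quant(A)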
@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
if quant_type == "linear":
norm = S1 * S2 / (C * C)
# double cast needed to prevent overflows
return (xq.float() * norm).to(dtype)
elif quant_type == "zeropoint":
norm = 1.0 / (S1 * S2)
return (xq.float() * norm).to(dtype)
elif quant_type == "row-zeropoint":
norm = 1.0 / (S1 * S2)
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= norm
else:
x *= norm
return x.to(dtype)
elif quant_type == "vector-zeropoint":
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= 1.0 / S1
else:
x *= 1.0 / S1
x *= 1.0 / S2.t()
return x.to(dtype)
elif quant_type == "row":
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1 * S2 / (C * C)
else:
x *= S1 * S2 / (C * C)
return x.to(dtype)
elif quant_type in ["truncated-vector", "vector"]:
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1 / C
else:
x *= S1 / C
x *= S2 / C
return x.to(dtype)
else:
return None
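Continuing the sketch above, the maintained int8 pipeline pairs the quantized operands with int8_linear_matmul and int8_mm_dequant; the dequant signature below is assumed from the current API (only the function names appear in this diff), so treat it as a sketch rather than a reference:

import torch
import bitsandbytes.functional as F

A = torch.randn(64, 128, device="cuda", dtype=torch.float16)
B = torch.randn(256, 128, device="cuda", dtype=torch.float16)

CA, statsA, _ = F.int8_vectorwise_quant(A)
CB, statsB, _ = F.int8_vectorwise_quant(B)

C_i32 = F.int8_linear_matmul(CA, CB)            # int32 result of A @ B.T
out = F.int8_mm_dequant(C_i32, statsA, statsB)  # assumed args: (result, row stats, col stats)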
def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
quant_state = linear.weight.quant_state
...
@@ -357,92 +357,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
}
}
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
{
const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
const int numThreads = blockDim.x*gridDim.x;
for(int i = tid; i < n; i+=numThreads)
{
int idx = (index1[i]*maxidx1) + index2[i];
atomicAdd(&histogram[idx], src[i]);
}
}
#define THREADS_ESTIMATE 512
#define NUM_ESTIMATE 8
#define BLOCK_ESTIMATE 4096
template<typename T>
__launch_bounds__(THREADS_ESTIMATE, 1)
__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
{
const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));
T vals[NUM_ESTIMATE];
typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ union {
typename LoadFloat::TempStorage loadf;
typename BlockRadixSort::TempStorage sort;
int smem_qidx[BLOCK_ESTIMATE];
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
{
valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;
// do not process half-blocks
if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }
#pragma unroll 4
for(int j = 0; j < NUM_ESTIMATE; j++)
vals[j] = max_val;
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);
#pragma unroll 4
for(int j = 0; j < NUM_ESTIMATE; j++)
vals[j] = ((float)vals[j]) * reciprocal_num_blocks;
__syncthreads();
// sort into striped pattern to mitigate bank conflicts
// striped pattern index for thread 0 [0, 1024, 2048, 3096]
// striped pattern index for thread 1 [1, 1025, 2049, 3097]
BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);
__syncthreads();
for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
temp_storage.smem_qidx[j] = -1;
__syncthreads();
if(threadIdx.x < 256)
{
float q_interval = (1.0f-(2.0f*offset))/255.0f;
int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
temp_storage.smem_qidx[local_idx] = threadIdx.x;
}
__syncthreads();
for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x)
{
if(temp_storage.smem_qidx[i] != -1)
atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
}
}
}
__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
@@ -2998,9 +2912,6 @@ template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n);
template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
...
@@ -10,8 +10,6 @@
#define kernels
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
@@ -106,10 +104,6 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(
...
@@ -18,23 +18,6 @@ using namespace BinSearch;
using std::cout;
using std::endl;
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512;
int num_blocks = n/threads;
num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
{
@@ -618,9 +601,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
...
@@ -136,9 +136,6 @@ class ContextCusparse
};
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
@@ -165,8 +162,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount);
...
@@ -19,9 +19,6 @@
//===================================================================================
#if BUILD_CUDA
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
@@ -169,8 +166,6 @@ void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_r
extern "C"
{
#if BUILD_CUDA
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); }
@@ -271,7 +266,6 @@ extern "C"
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
...
import numpy as np
import pytest
from scipy.stats import norm
import torch
import bitsandbytes as bnb
@@ -9,70 +7,6 @@ from tests.helpers import BOOLEAN_TRIPLES, describe_dtype, get_test_dims, id_for
from tests.test_autograd import TRANSPOSE_VALS
@pytest.mark.deprecated
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1 - val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 4):
total_values = 2**bits - 1
p = np.linspace(0, 1, 2 * total_values + 1)
idx = np.arange(1, 2 * total_values + 1, 2)
p = p[idx]
offset = 1 / (2 * total_values)
p = np.linspace(offset, 1 - offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
err = torch.abs(val1 - val2).mean()
assert err < 0.035
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0
@pytest.mark.deprecated
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
@pytest.mark.deprecated
def test_dynamic_quantization():
diffs = []
@@ -208,3 +142,34 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
grad_err = (gradB1 - gradB2).abs().mean()
assert grad_err.item() < 0.003
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
@pytest.mark.deprecated
def test_fp8linear():
b = 10
h = 1024
inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h * 2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
fp32b = torch.nn.Linear(h * 2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)
fp8b.weight.data.copy_(fp32b.weight.data)
fp8b.bias.data.copy_(fp32b.bias.data)
a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp)))
err = (a - b).abs().mean()
a.mean().backward()
b.mean().backward()
graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
assert err < 0.05
assert graderr < 0.00002
assert bgraderr < 0.00002
@@ -170,7 +170,7 @@ class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"])
def test_few_bit_quant(self, device, bits, method):
if device in ("cpu", "xpu") and bits != 8:
pytest.skip("CPU/XPU implementation only supports 8 bits")
@@ -186,11 +186,7 @@ class Test8BitBlockwiseQuantizeFunctional:
code = F.create_fp8_map(True, ebits, pbits, bits).to(device)
elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).to(device)
elif method == "quantile":
if device != "cuda":
pytest.skip("Quantile map only works on CUDA")
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
@@ -564,6 +560,30 @@ class TestIGEMMFunctional:
class TestLLMInt8Functional:
@staticmethod
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half):
"""Reference implementation for the F.int8_mm_dequant function."""
C = 127.0
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1 / C
else:
x *= S1 / C
x *= S2 / C
return x.to(dtype)
@staticmethod
def vectorwise_quant(x, dim=1):
"""Reference implementation"""
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x * (127.0 / max1)).to(torch.int8)
return xq, max1
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@@ -625,12 +645,12 @@ class TestLLMInt8Functional:
if has_bias:
C1 += bias
A1, maxA = F.vectorwise_quant(A, dim=1)
A1, maxA = self.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
B1, maxB = self.vectorwise_quant(B, dim=1)
C2 = F.int8_linear_matmul(A1, B1)
C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
C4 = self.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
if has_bias:
C4 += bias
@@ -694,8 +714,8 @@ class TestLLMInt8Functional:
def test_int8_double_quant(self, dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_col1, Scol = self.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
out_row1, Srow = self.vectorwise_quant(A, dim=1)
CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)
@@ -747,8 +767,8 @@ class TestLLMInt8Functional:
C1a, stats1a, _ = F.int8_vectorwise_quant(A)
C2a, stats2a, _ = F.int8_vectorwise_quant(B)
A1, maxA = F.vectorwise_quant(A, dim=1)
A1, maxA = self.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
B1, maxB = self.vectorwise_quant(B, dim=1)
torch.testing.assert_close(maxA.flatten().float(), stats1a)
torch.testing.assert_close(maxB.flatten().float(), stats2a)
@@ -759,7 +779,7 @@ class TestLLMInt8Functional:
C2 = F.int8_linear_matmul(A1, B1)
out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
out3 = self.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
@@ -892,8 +912,9 @@ class TestSpMMFunctional:
else:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
B, SB = F.vectorwise_quant(B, quant_type="linear")
# B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
SB = torch.abs(B).max().float()
B = torch.round(B / SB * 127).to(torch.int8)
print("")
idx = torch.abs(A) >= threshold
@@ -1368,26 +1389,3 @@ def test_normal_map_tree():
for i in idx:
pivots.append((values[i - 1] + values[i]) / 2)
# print(pivots)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32 * 10
A = F.get_paged(n, n, dtype=torch.float32)
B = F.get_paged(n, n, dtype=torch.uint8)
B2 = F.get_paged(n, n, dtype=torch.float32)
assert A.is_paged
assert B.is_paged
assert A.page_deviceid == 0
assert B.page_deviceid == 0
F.fill(A, 17.0)
F.fill(B, 17)
F.fill(B2, 2)
assert (A == 17).sum().item() == n * n
assert (B == 17).sum().item() == n * n
C = A * B.float()
assert (C == 289).sum().item() == n * n
F._mul(A, B2)
F._mul(A, B2)
F._mul(A, B2)
assert (A == 17 * (2**3)).sum().item() == n * n
@@ -343,37 +343,6 @@ def test_kbit_backprop(device, module):
assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
@pytest.mark.deprecated
def test_fp8linear():
b = 10
h = 1024
inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h * 2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
fp32b = torch.nn.Linear(h * 2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)
fp8b.weight.data.copy_(fp32b.weight.data)
fp8b.bias.data.copy_(fp32b.bias.data)
a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp)))
err = (a - b).abs().mean()
a.mean().backward()
b.mean().backward()
graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
assert err < 0.05
assert graderr < 0.00002
assert bgraderr < 0.00002
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
...
@@ -289,11 +289,6 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
optimizer_names_8bit = [
# Non-blockwise optimizers are deprecated.
# "adam8bit",
# "lion8bit",
# "momentum8bit",
# "rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"momentum8bit_blockwise",
@@ -310,11 +305,9 @@ optimizer_names_8bit = [
def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
torch.set_printoptions(precision=6)
if gtype == torch.bfloat16 and "blockwise" not in optim_name:
pytest.skip()
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
@@ -349,8 +342,6 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
# print(bnb_optimizer.state[p2][max_val], name1)
if "blockwise" in optim_name:
## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
## separately and then stack them. The qmap is shared, but absmax is also stacked.
if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
@@ -368,7 +359,6 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
)
s1 = torch.stack((m1, m2))
else:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
@@ -376,12 +366,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
# assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
@@ -414,7 +399,6 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
if "blockwise" in optim_name:
## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
## separately and then stack them. The qmap is shared, but absmax is also stacked.
if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
@@ -441,12 +425,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
torch.testing.assert_close(s1cpy, s1)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
@@ -463,9 +442,6 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
torch_optimizer.state[p1][name1].copy_(s.data)
# print(sum(errors)/len(errors))
# print(sum(relerrors)/len(relerrors))
@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
...