# How to override config hyperparameters for particular weights/parameters
If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With it, you can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, two things are needed: (1) register the parameters while they are still on the CPU, and (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details.
For global overrides used in many different places in your code, you can do:
```python
import torch
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, 'optim_bits', 32)
# 2b. override: the two special layers use
# sparse optimization + different learning rate + different Adam betas
mng.override_config([model.special.weight, model.also_special.weight],
                    key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
```
Possible options for the config override are: `betas`, `eps`, `weight_decay`, `lr`, `optim_bits`, `min_8bit_size`, `percentile_clipping`, `block_wise`, `max_unorm`.
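For example, a hypothetical embedding layer could be kept in 32-bit with more aggressive percentile clipping. This is a minimal sketch reusing `mng` and `model` from the snippet above; `model.emb` is a placeholder name, not part of the original example:
```python
# hedged sketch: `model.emb` is a hypothetical embedding layer on `model`
mng.override_config(model.emb.weight,
                    key_value_dict={'optim_bits': 32,
                                    'percentile_clipping': 5,
                                    'block_wise': False})
```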
For overrides of particular layers, we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the `GlobalOptimManager`:
```python
import torch
from bitsandbytes.optim import GlobalOptimManager

class MyModule(torch.nn.Module):
    def __init__(self, din, dout):
        super(MyModule, self).__init__()
        self.linear = torch.nn.Linear(din, dout)
        # optimization will happen in 32-bit and
        # the learning rate will be set to 0.0001, independent of the main learning rate
        config = {'optim_bits': 32, 'lr': 0.0001}
        GlobalOptimManager.get_instance().register_module_override(self, 'weight', config)
```
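A minimal usage sketch under these assumptions (the `MyModule` class above, a CUDA device available, and the 8-bit `bnb.optim.Adam` shown earlier); the dimensions and data are placeholders:
```python
import torch
import bitsandbytes as bnb

model = MyModule(64, 32).cuda()  # the override is registered in __init__, while weights are on CPU
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)

out = model.linear(torch.randn(4, 64, device='cuda'))
out.mean().backward()
adam.step()  # the overridden parameter is stepped with 32-bit state and lr=0.0001
```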
#pragma once
#include "Portable.h"
namespace BinSearch {
namespace Details {
template <typename T>
bool isAligned(const T *p, size_t A)
{
return (reinterpret_cast<size_t>(p) % A) == 0;
}
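// Over-aligned dynamic array: allocates A extra bytes and shifts the data
// pointer so that m_data is A-byte aligned (A = 64 matches a cache line).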
template <class T, size_t A=64>
struct AlignedVec
{
AlignedVec()
: m_storage(0)
, m_data(0)
, m_sz(0)
{
}
static size_t nBytes(size_t sz)
{
return sz * sizeof(T) + A;
}
static size_t shiftAmt(char *p)
{
return A>1? (A - (reinterpret_cast<size_t>(p) % A)) % A: 0;
}
void setPtr(char *p, size_t sz)
{
m_sz = sz;
m_data = reinterpret_cast<T *>(p + shiftAmt(p));
}
//void setPtr(T *p, size_t sz)
//{
// m_sz = sz;
// if (A>1)
// myassert(((reinterpret_cast<size_t>(p) % A) == 0), "bad alignment");
// m_data = p;
//}
// internal allocation
void resize(size_t sz)
{
m_storage = new char[nBytes(sz)];
setPtr(m_storage, sz);
}
// external allocation
void set(char *storage, size_t sz)
{
setPtr(storage, sz);
}
~AlignedVec()
{
if (m_storage)
delete [] m_storage;
}
size_t size() const { return m_sz; }
T& operator[](size_t i) { return m_data[i]; }
const T& operator[](size_t i) const { return m_data[i]; }
T* begin() { return m_data; }
T* end() { return m_data+m_sz; }
const T* begin() const { return m_data; }
const T* end() const { return m_data+m_sz; }
T& front() { return m_data[0]; }
T& back() { return m_data[m_sz-1]; }
const T& front() const { return m_data[0]; }
const T& back() const { return m_data[m_sz - 1]; }
private:
char *m_storage;
T *m_data;
size_t m_sz;
};
} // namespace Details
} // namespace BinSearch
#pragma once
#include <algorithm>
#include <limits>
#include <type_traits>
#include "AAlloc.h"
namespace BinSearch {
namespace Details {
namespace DirectAux {
#define SAFETY_MULTI_PASS true
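// Result of the H computation below: the scaling factor H, its relative
// growth over the initial guess (hRatio), and the number of increments (nInc).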
template <typename T>
struct HResults
{
HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {}
T H;
double hRatio;
size_t nInc;
};
#ifdef USE_FMA
template <Algos A> struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); };
template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); };
template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); };
#else
template <Algos A> struct IsDirect { static const bool value = (A == Direct); };
template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2); };
template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache); };
#endif
// general definition
template <Algos A, typename T, typename Enable = void>
struct BucketElem
{
FORCE_INLINE void set( uint32 b, const T *)
{
m_b = b;
}
FORCE_INLINE uint32 index() const { return m_b; }
private:
uint32 m_b;
};
// specialization for DirectCache methods
template <typename T> struct MatchingIntType;
template <> struct MatchingIntType<double> { typedef uint64 type; };
template <> struct MatchingIntType<float> { typedef uint32 type; };
template <Algos A, typename T>
struct BucketElem<A, T, typename std::enable_if< IsDirectCache<A>::value >::type >
{
typedef typename MatchingIntType<T>::type I;
void set(uint32 b, const T *xi)
{
u.u.x = xi[b];
u.u.b = b;
}
FORCE_INLINE I index() const { return u.u.b; }
FORCE_INLINE T x() const { return u.u.x; }
private:
union {
double dummy;
struct
{
T x;
I b;
} u;
} u;
};
template <bool UseFMA, unsigned char Gap, typename T>
struct DirectTraits
{
static void checkH(T scaler, T x0, T xN)
{
T Dn = xN - x0;
T ifmax = Dn * scaler;
myassert((ifmax < std::numeric_limits<uint32>::max() - (Gap - 1)),
"Problem unfeasible: index size exceeds uint32 capacity:"
<< " D[N] =" << Dn
<< ", H =" << scaler
<< ", H D[n] =" << ifmax << "\n"
);
}
FORCE_INLINE static uint32 f(T scaler, T x0, T z)
{
T tmp = scaler * (z - x0);
#ifdef USE_SSE2
return ftoi(FVec1<SSE,T>(tmp));
#else
return static_cast<uint32>(tmp);
#endif
}
template <InstrSet I>
FORCE_INLINE static typename FTOITraits<I, T>::vec_t f(const FVec<I, T>& scaler, const FVec<I, T>& x0, const FVec<I, T>& z)
{
return ftoi(scaler*(z-x0));
}
static T cst0(T scaler, T x0)
{
return x0;
}
};
#ifdef USE_FMA
template <unsigned char Gap, typename T>
struct DirectTraits<true,Gap,T>
{
typedef FVec1<SSE, T> fVec1;
static void checkH(T scaler, T H_Times_x0, T xN)
{
union {
typename FVec1<SSE, T>::vec_t v;
T s;
} ifmax;
ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0));
myassert((ifmax.s < std::numeric_limits<uint32>::max() - (Gap - 1)),
"Problem unfeasible: index size exceeds uint32 capacity:"
<< " H X[0] =" << H_Times_x0
<< ", H =" << scaler
<< ", X[N] =" << xN
<< ", H X[N] - H X[0] =" << ifmax.s << "\n"
);
}
FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi)
{
return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0)));
}
template <InstrSet I>
FORCE_INLINE static typename FTOITraits<I,T>::vec_t f(const FVec<I,T>& scaler, const FVec<I, T>& H_Times_X0, const FVec<I, T>& z)
{
return ftoi(mulSub(scaler, z, H_Times_X0));
}
static T cst0(T scaler, T x0)
{
return scaler*x0;
}
};
#endif
template <unsigned char Gap, typename T, Algos A>
struct DirectInfo
{
static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA);
typedef DirectTraits<UseFMA, Gap, T> fun_t;
typedef BucketElem<A,T> bucket_t;
typedef AlignedVec<bucket_t> bucketvec_t;
struct Data {
Data() : buckets(0), xi(0), scaler(0), cst0(0) {}
Data( const T *x // for Direct must persist if xws=NULL
, uint32 n
, T H
, bucket_t *bws // assumed to have size nb, as computed below
, T *xws = NULL // assumed to have size (n+Gap-1). Optional for Direct, unused for DirectCache, required for DirectGap
)
: buckets(bws)
, scaler(H)
, cst0(fun_t::cst0(H, x[0]))
{
myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned");
uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]);
const uint32 npad = Gap-1;
const uint32 n_sz = n + npad; // size of padded vector
if (xws) {
myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned");
std::fill_n(xws, npad, x[0]); // pad in front with x[0]
std::copy(x, x+n, xws + npad);
xi = xws;
}
else {
myassert(Gap==1, "if Gap>1 then X workspace must be provided");
xi = x;
}
populateIndex(bws, nb, xi, n_sz, scaler, cst0);
}
const bucket_t *buckets;
const T *xi;
T scaler;
T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not
} data;
static T growStep(T H)
{
T step;
T P = next(H);
while ((step = P - H) == 0)
P = next(P);
return step;
}
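// Find a scaling factor H such that f(H, cst0, x) assigns strictly increasing
// bucket indices to the (Gap-spaced) abscissae. Starting from the initial
// guess H0 = 1 / min(D_i - D_{i-Gap}), H is grown by trial and error until
// a full verification pass succeeds.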
static HResults<T> computeH(const T *px, uint32 nx)
{
myassert((nx > Gap), "Array X too small");
myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap");
const T x0 = px[0];
const T xN = px[nx-1];
const T range = xN - x0;
myassert((range < std::numeric_limits<T>::max()), "range too large");
// check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i
T deltaDMin = range;
for (uint32 i = Gap; i < nx; ++i) {
T Dnew = px[i] - x0;
T Dold = px[i - Gap] - x0;
myassert((Dnew > Dold),
"Problem unfeasible: D_i sequence not strictly increasing"
<< " X[" << 0 << "]=" << x0
<< " X[" << i - Gap << "]=" << px[i - Gap]
<< " X[" << i << "]=" << px[i]
<< "\n"
);
T deltaD = Dnew - Dold;
if (deltaD < deltaDMin)
deltaDMin = deltaD;
}
// initial guess for H
const T H0 = T(1.0) / deltaDMin;
T H = H0;
T cst0 = fun_t::cst0(H, x0);
fun_t::checkH(H, cst0, xN);
// adjust H by trial and error until succeed
size_t nInc = 0;
bool modified = false;
size_t npasses = 0;
T step = growStep(H);
uint32 seg_already_checked_from = nx;
do {
myassert((npasses++ < 2), "verification failed\n");
// if there has been an increase, then check only up to that point
uint32 last_seg_to_be_checked = seg_already_checked_from - 1;
modified = false;
for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) {
uint32 iold = fun_t::f(H, cst0, px[i-Gap]);
uint32 inew = fun_t::f(H, cst0, px[i]);
while (inew == iold) {
seg_already_checked_from = i;
last_seg_to_be_checked = nx-1; // everything needs to be checked
modified = true;
H = H + step;
step *= 2;
// recalculate all constants and indices
cst0 = fun_t::cst0(H, x0);
fun_t::checkH(H, cst0, xN);
iold = fun_t::f(H, cst0, px[i - Gap]);
inew = fun_t::f(H, cst0, px[i]);
}
}
} while (SAFETY_MULTI_PASS && modified);
return HResults<T>(H, (((double)H) / H0) - 1.0, nInc);
}
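// Fill the bucket index: walk x and the buckets from the back, storing for
// each bucket b the index j of the first X-element to be checked for it.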
static void populateIndex(BucketElem<A, T> *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0)
{
for (uint32 i = x_size-1, b = index_size-1, j=0; ; --i) {
uint32 idx = fun_t::f(scaler, cst0, px[i]);
while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false
buckets[b].set( j, px );
--b;
}
if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant
j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check
buckets[b].set(j, px);
if (b-- == 0)
break;
}
}
}
DirectInfo(const Data& d)
: data(d)
{
}
DirectInfo(const T* px, const uint32 n)
{
HResults<T> res = computeH(px, n);
#ifdef PAPER_TEST
nInc = res.nInc;
hRatio = res.hRatio;
#endif
const uint32 npad = Gap-1;
const uint32 n_sz = n + npad; // size of padded vector
if (npad)
xi.resize(n_sz);
T H = res.H;
T cst0 = fun_t::cst0(H, px[0]);
const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]);
buckets.resize(maxIndex + 1);
data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL));
}
private:
bucketvec_t buckets;
AlignedVec<T,8> xi;
#ifdef PAPER_TEST
public:
double hRatio;
size_t nInc;
#endif
};
} // namespace DirectAux
} // namespace Details
} // namespace BinSearch
#pragma once
#include "Algo-Direct-Common.h"
namespace BinSearch
{
namespace Details
{
template <typename T, Algos A>
struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
{
private:
typedef DirectAux::DirectInfo<2, T, A> base_t;
static const size_t Offset = 2;
public:
AlgoScalarBase(const T *x, const uint32 n)
: base_t(x, n)
{
}
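// Scalar lookup for the Direct2 variants: map z to a bucket, fetch the
// candidate index stored for that bucket, then correct it downward with at
// most two comparisons against the padded X array.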
FORCE_INLINE uint32 scalar(T z) const
{
const T *px = base_t::data.xi;
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
uint32 iidx = buckets[bidx];
px += iidx;
if (z < *px)
--iidx;
if (z < *(px + 1))
--iidx;
return iidx;
}
};
template <InstrSet I, typename T, Algos A>
struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
{
static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
typedef FVec<I, T> fVec;
typedef IVec<SSE, T> i128;
struct Constants
{
fVec vscaler;
fVec vcst0;
IVec<I, T> one;
};
private:
typedef AlgoScalarBase<T, A> base_t;
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<SSE, float> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
{
union U
{
__m128i vec;
uint32 ui32[4];
} u;
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const float *xi = base_t::data.xi;
// read indices t
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
#if 0
// read pairs ( X(t-1), X(t) )
__m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3));
__m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2));
__m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1));
__m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0));
// build:
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
__m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6));
__m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6));
__m128 u01 = _mm_unpacklo_ps(h02, h13);
__m128 u23 = _mm_unpackhi_ps(h02, h13);
__m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6));
__m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6));
#else
__m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
__m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
__m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
#endif
IVec<SSE, float> i(u.vec);
IVec<SSE, float> vlem = (vz < vxm);
IVec<SSE, float> vlep = (vz < vxp);
i = i + vlem + vlep;
i.store(pr);
}
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<SSE, double> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
{
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const double *xi = base_t::data.xi;
uint32 b1 = buckets[bidx.get1()];
uint32 b0 = buckets[bidx.get0()];
const double *p1 = &xi[b1];
const double *p0 = &xi[b0];
// read pairs ( X(t-1), X(t) )
__m128d vx1 = _mm_loadu_pd(p1);
__m128d vx0 = _mm_loadu_pd(p0);
// build:
// { X(t(0)-1), X(t(1)-1) }
// { X(t(0)), X(t(1)) }
__m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
IVec<SSE, double> i(b1, b0);
IVec<SSE, double> vlem = (vz < vxm);
IVec<SSE, double> vlep = (vz < vxp);
i = i + vlem + vlep;
union
{
__m128i vec;
uint32 ui32[4];
} u;
u.vec = i;
pr[0] = u.ui32[0];
pr[1] = u.ui32[2];
}
#ifdef USE_AVX
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<AVX, float> &vz, const IVec<AVX, float> &bidx, uint32 *pr) const
{
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const float *xi = base_t::data.xi;
#if 0 // use gather instructions
IVec<AVX,float> idxm;
idxm.setidx(buckets, bidx);
__m256i z = _mm256_setzero_si256();
IVec<AVX,float> minusone = _mm256_cmpeq_epi32(z,z);
IVec<AVX,float> idxp = idxm - minusone;
FVec<AVX, float> vxm = _mm256_i32gather_ps(xi, idxm, sizeof(float));
FVec<AVX, float> vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float));
IVec<AVX, float> ip = idxm;
#else // do not use gather instrucions
union U
{
__m256i vec;
uint32 ui32[8];
} u;
// read indices t
const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
#if 0 // perform 8 loads in double precision
// read pairs ( X(t-1), X(t) )
__m128 xp7 = _mm_castpd_ps(_mm_load_sd(p7));
__m128 xp6 = _mm_castpd_ps(_mm_load_sd(p6));
__m128 xp5 = _mm_castpd_ps(_mm_load_sd(p5));
__m128 xp4 = _mm_castpd_ps(_mm_load_sd(p4));
__m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3));
__m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2));
__m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1));
__m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0));
// build:
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
__m128 h57 = _mm_shuffle_ps(xp5, xp7, (1 << 2) + (1 << 6)); // F- F+ H- H+
__m128 h46 = _mm_shuffle_ps(xp4, xp6, (1 << 2) + (1 << 6)); // E- E+ G- G+
__m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); // B- B+ D- D+
__m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); // A- A+ C- C+
__m128 u01 = _mm_unpacklo_ps(h02, h13); // A- B- A+ B+
__m128 u23 = _mm_unpackhi_ps(h02, h13); // C- D- C+ D+
__m128 u45 = _mm_unpacklo_ps(h46, h57); // E- F- E+ F+
__m128 u67 = _mm_unpackhi_ps(h46, h57); // G- H- G+ H+
__m128 abcdm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // A- B- C- D-
__m128 abcdp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // A+ B+ C+ D+
__m128 efghm = _mm_shuffle_ps(u45, u67, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // E- F- G- H-
__m128 efghp = _mm_shuffle_ps(u45, u67, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // E+ F+ G+ H+
FVec<AVX, float> vxm = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdm), efghm, 1);
FVec<AVX, float> vxp = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdp), efghp, 1);
IVec<AVX, float> ip(u.vec);
#else // use __mm256_set_pd
// read pairs ( X(t-1), X(t) )
__m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
__m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
// { x0(t-1), x1(t-1), x2(t-1), x3(t-1), x4(t-1), x5(t-1), x6(t-1), x7(t-1) }
FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6));
// { x0(t), x1(t), x2(t), x3(t), x4(t), x5(t), x6(t), x7(t) }
FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6));
IVec<AVX, float> ip(u.vec);
#endif
#endif
IVec<AVX, float> vlem = vz < vxm;
IVec<AVX, float> vlep = vz < vxp;
ip = ip + vlem + vlep;
ip.store(pr);
}
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<AVX, double> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
{
union
{
__m256i vec;
uint64 ui64[4];
} u;
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const double *xi = base_t::data.xi;
// read indices t
const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
// read pairs ( X(t-1), X(t) )
__m128d xp3 = _mm_loadu_pd(p3);
__m128d xp2 = _mm_loadu_pd(p2);
__m128d xp1 = _mm_loadu_pd(p1);
__m128d xp0 = _mm_loadu_pd(p0);
// build:
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
__m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
__m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02, x13);
FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02, x13);
// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
IVec<AVX, double> i(u.vec);
IVec<AVX, double> vlem = vz < vxm;
IVec<AVX, double> vlep = vz < vxp;
i = i + vlem + vlep;
i.extractLo32s().store(pr);
}
#endif
public:
AlgoVecBase(const T *x, const uint32 n) : base_t(x, n) {}
void initConstants(Constants &cst) const
{
cst.vscaler.setN(base_t::data.scaler);
cst.vcst0.setN(base_t::data.cst0);
cst.one.setN(uint32(1));
}
void vectorial(uint32 *pr, const T *pz, const Constants &cst) const
{
fVec vz(pz);
resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
}
};
} // namespace Details
} // namespace BinSearch
ALGOENUM(DirectCacheFMA, 5)
ALGOENUM(DirectFMA, 15)
ALGOENUM(Direct2FMA, 25)
ALGOENUM(DirectCache, 10)
ALGOENUM(Direct, 20)
ALGOENUM(Direct2, 30)
ALGOENUM(Nonary, 40)
ALGOENUM(Pentary, 50)
ALGOENUM(Ternary, 60)
ALGOENUM(Eytzinger, 70)
ALGOENUM(BitSet, 80)
ALGOENUM(ClassicOffset, 90)
#ifdef PAPER_TEST
ALGOENUM(MorinOffset, 100)
ALGOENUM(BitSetNoPad, 110)
ALGOENUM(ClassicMod, 120)
ALGOENUM(MorinBranchy, 130)
ALGOENUM(Classic, 140)
ALGOENUM(LowerBound, 145)
#ifdef USE_MKL
ALGOENUM(MKL, 150)
#endif
#endif
#pragma once
#include "Type.h"
#include <algorithm>
namespace BinSearch {
template <InstrSet I, typename T, Algos A, bool L=false, bool R=false>
struct BinAlgo : Details::BinAlgoBase<I,T,A>
{
typedef Details::BinAlgoBase<I,T,A> base_t;
BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {}
BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {}
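// The L/R template flags enable lower/upper bound checks: an out-of-range z
// returns uint32(-1) on the left and N on the right; otherwise the base
// algorithm is invoked.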
FORCE_INLINE
uint32 scalar(T z) const
{
if (!L || z >= x0)
if (!R || z < xN)
return base_t::scalar(z);
else
return N;
else
return std::numeric_limits<uint32>::max();
}
FORCE_INLINE
void vectorial(uint32 *pr, const T *pz, uint32 n) const
{
if (!L && !R) {
Details::Loop<T,base_t>::loop(*this, pr, pz, n);
}
else {
const uint32 nElem = base_t::nElem;
const uint32 idealbufsize = 256;
const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 1 : 0));
T databuf[bufsize];
uint32 resbuf[bufsize];
uint32 indexbuf[bufsize];
uint32 *prend = pr + n;
while(pr != prend) {
uint32 cnt = 0;
uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend));
for (uint32 j = 0; j < niter; ++j) {
T z = pz[j];
// FIXME: use SSE2?
if (!L || z >= x0)
if (!R || z < xN) {
databuf[cnt] = z;
indexbuf[cnt] = j;
++cnt;
}
else
pr[j] = N;
else
pr[j] = std::numeric_limits<uint32>::max();
}
// FIXME: merge these two loops
Details::Loop<T,base_t>::loop(*this, resbuf, databuf, cnt);
for (uint32 j = 0; j < cnt; ++j)
pr[indexbuf[j]] = resbuf[j];
pr += niter;
pz += niter;
}
}
}
Details::CondData<T,L> x0;
Details::CondData<T,R> xN;
Details::CondData<uint32,R> N;
};
} // namespace BinSearch
#pragma once
#include "AAlloc.h"
#include "BinAlgo.h"
#include "SIMD.h"
#include <algorithm>
#include <limits>
#include "Algo-Direct2.h"
#pragma once
#include <limits>
#include <cmath>
#include <stdexcept>
#include <sstream>
#ifdef __FMA__
#define USE_FMA
#endif
#ifdef __AVX2__
#define USE_AVX2
#endif
#ifdef __AVX__
#define USE_AVX
#endif
#ifdef __SSE4_1__
#define USE_SSE41
#endif
#ifdef __SSE4_2__
#define USE_SSE42
#endif
#ifndef _MSC_VER
#include <stdint.h>
#endif
namespace BinSearch {
#ifndef _MSC_VER
typedef int8_t int8;
typedef uint8_t uint8;
typedef int32_t int32;
typedef uint32_t uint32;
typedef int64_t int64;
typedef uint64_t uint64;
#else
typedef __int8 int8;
typedef unsigned __int8 uint8;
typedef __int32 int32;
typedef unsigned __int32 uint32;
typedef __int64 int64;
typedef unsigned __int64 uint64;
#endif
namespace Details {
#define myassert(cond, msg) if (!(cond)) { std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); }
// log2 is not defined in VS2008
#if defined(_MSC_VER)
inline uint32 log2 (uint32 val) {
if (val == 1) return 0;
uint32 ret = 0;
do {
ret++;
val >>= 1;
} while (val > 1);
return ret;
}
#endif
#ifdef _DEBUG
#define DEBUG
#endif
#ifdef _MSC_VER
# define FORCE_INLINE __forceinline
# define NO_INLINE __declspec(noinline)
#else
# define NO_INLINE __attribute__((noinline))
# ifdef DEBUG
# define FORCE_INLINE NO_INLINE
# else
# define FORCE_INLINE __attribute__((always_inline)) inline
# endif
#endif
#ifdef USE_AVX
#define COMISS "vcomiss"
#define COMISD "vcomisd"
#else
#define COMISS "comiss"
#define COMISD "comisd"
#endif
// nextafter is not defined in VS2008
#if defined(_MSC_VER) && (_MSC_VER <= 1500)
#include <float.h>
inline float mynext(float x)
{
return _nextafterf(x, std::numeric_limits<float>::max());
}
inline double mynext(double x)
{
return _nextafter(x, std::numeric_limits<double>::max());
}
inline float myprev(float x)
{
return _nextafterf(x, -std::numeric_limits<float>::max());
}
inline double myprev(double x)
{
return _nextafter(x, -std::numeric_limits<double>::max());
}
#else
inline float mynext(float x)
{
return std::nextafterf(x, std::numeric_limits<float>::max());
}
inline double mynext(double x)
{
return std::nextafter(x, std::numeric_limits<double>::max());
}
inline float myprev(float x)
{
return std::nextafterf(x, -std::numeric_limits<float>::max());
}
inline double myprev(double x)
{
return std::nextafter(x, -std::numeric_limits<double>::max());
}
#endif
template <typename T>
inline T next(T x)
{
for (int i = 0; i < 4; ++i)
x = mynext(x);
return x;
}
template <typename T>
inline T prev(T x)
{
for (int i = 0; i < 4; ++i)
x = myprev(x);
return x;
}
} // namespace Details
} // namespace BinSearch
#pragma once
#include "Portable.h"
#ifdef USE_SSE42
#ifndef _MSC_VER
#include <popcntintrin.h>
#define popcnt32 _mm_popcnt_u32
#else
#include <intrin.h>
#define popcnt32 __popcnt
#endif
#else // USE_SSE42
namespace BinSearch {
FORCE_INLINE int popcnt32(int x32)
{
// strictly speaking this is not correct, as it ignores higher order bits
// however, this is only used on the result of movemask on a 128-bit register, whose value fits in 8 bits, so it is ok
// with 256-bit registers, SSE42 is defined, and we do not use this function
uint8 x = static_cast<uint8>(x32);
x = (x & 0x55) + (x >> 1 & 0x55);
x = (x & 0x33) + (x >> 2 & 0x33);
x = (x & 0x0f) + (x >> 4 & 0x0f);
return x;
}
} // namespace
#endif
#if defined(USE_AVX) || defined(USE_AVX2)
#include <immintrin.h>
#else
#include <emmintrin.h>
#ifdef USE_SSE41
#include <smmintrin.h>
#endif
#endif
#include "Type.h"
namespace BinSearch {
namespace Details {
template <InstrSet I, class T>
struct FVec;
template <InstrSet I, class T>
struct IVec;
template <InstrSet I, class T>
struct FVec1;
template <> struct InstrIntTraits<SSE>
{
typedef __m128i vec_t;
};
template <> struct InstrFloatTraits<SSE, float>
{
typedef __m128 vec_t;
};
template <> struct InstrFloatTraits<SSE, double>
{
typedef __m128d vec_t;
};
template <InstrSet I, typename T>
struct FTOITraits
{
typedef IVec<SSE, float> vec_t;
};
#ifdef USE_AVX
template <>
struct FTOITraits<AVX, float>
{
typedef IVec<AVX, float> vec_t;
};
template <> struct InstrIntTraits<AVX>
{
typedef __m256i vec_t;
};
template <> struct InstrFloatTraits<AVX, float>
{
typedef __m256 vec_t;
};
template <> struct InstrFloatTraits<AVX, double>
{
typedef __m256d vec_t;
};
#endif
template <typename TR>
struct VecStorage
{
typedef typename TR::vec_t vec_t;
FORCE_INLINE operator vec_t&() { return vec; }
FORCE_INLINE operator const vec_t&() const { return vec; }
protected:
FORCE_INLINE VecStorage() {}
FORCE_INLINE VecStorage(const vec_t& v) : vec( v ) {}
vec_t vec;
};
template <InstrSet>
struct IVecBase;
template <>
struct IVecBase<SSE> : VecStorage<InstrIntTraits<SSE>>
{
protected:
FORCE_INLINE IVecBase() {}
FORCE_INLINE IVecBase( const vec_t& v) : VecStorage<InstrIntTraits<SSE>>( v ) {}
public:
FORCE_INLINE static vec_t zero() { return _mm_setzero_si128(); }
FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32( vec ); }
FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask )
{
#ifdef USE_SSE41
vec = _mm_blendv_epi8(vec, val, mask);
#else
vec = _mm_or_si128(_mm_andnot_si128(mask,vec), _mm_and_si128(mask,val));
#endif
}
FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask)
{
vec = _mm_or_si128(vec, _mm_and_si128(val,mask));
}
};
template <>
struct IVec<SSE, float> : IVecBase<SSE>
{
FORCE_INLINE IVec() {}
FORCE_INLINE IVec( int32 i ) : IVecBase<SSE>( _mm_set1_epi32( i ) ) {}
FORCE_INLINE IVec( const vec_t& v) : IVecBase<SSE>( v ) {}
FORCE_INLINE IVec( uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase<SSE>( _mm_set_epi32( u3, u2, u1, u0 ) ) {}
void setN( int32 i ) { vec = _mm_set1_epi32( i ); }
#ifdef USE_SSE41
FORCE_INLINE int32 get1() const { return _mm_extract_epi32(vec, 1); }
FORCE_INLINE int32 get2() const { return _mm_extract_epi32(vec, 2); }
FORCE_INLINE int32 get3() const { return _mm_extract_epi32(vec, 3); }
#else
FORCE_INLINE int32 get1() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 1 ) ); }
FORCE_INLINE int32 get2() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); }
FORCE_INLINE int32 get3() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 3 ) ); }
#endif
FORCE_INLINE void store( uint32 *pi ) const { _mm_storeu_si128( reinterpret_cast<vec_t*>(pi), vec ); }
FORCE_INLINE int countbit()
{
return popcnt32(_mm_movemask_ps(_mm_castsi128_ps(vec)));
}
};
template <>
struct IVec<SSE, double> : IVecBase<SSE>
{
FORCE_INLINE IVec() {}
FORCE_INLINE IVec( int32 i ) : IVecBase<SSE>( _mm_set1_epi64x( i ) ) {}
FORCE_INLINE IVec( const vec_t& v) : IVecBase<SSE>( v ) {}
FORCE_INLINE IVec( uint64 u1, uint64 u0 ) : IVecBase<SSE>( _mm_set_epi64x(u1, u0) ) {}
void setN( int32 i ) { vec = _mm_set1_epi64x( i ); }
FORCE_INLINE int32 get1() const
{
#ifdef USE_SSE41
return _mm_extract_epi32(vec, 2);
#else
return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) );
#endif
}
// extract the 2 32 bits integers no. 0, 2 and store them in a __m128i
FORCE_INLINE IVec<SSE,float> extractLo32s() const
{
return _mm_shuffle_epi32(vec, ((2 << 2) | 0));
}
FORCE_INLINE void store( uint32 *pi ) const
{
pi[0] = get0();
pi[1] = get1();
}
FORCE_INLINE int countbit()
{
#if 1
// takes 4 cycles
__m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle
__m128i s = _mm_add_epi32(vec, hi);
int32 x = _mm_cvtsi128_si32(s);
return -x;
#else
// takes 6 cycles
return popcnt32(_mm_movemask_pd(_mm_castsi128_pd(vec)));
#endif
}
};
template <typename T>
FORCE_INLINE IVec<SSE,T> operator>> (const IVec<SSE,T>& a, unsigned n) { return _mm_srli_epi32(a, n); }
template <typename T>
FORCE_INLINE IVec<SSE,T> operator<< (const IVec<SSE,T>& a, unsigned n) { return _mm_slli_epi32(a, n); }
template <typename T>
FORCE_INLINE IVec<SSE,T> operator& (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_and_si128( a, b ); }
template <typename T>
FORCE_INLINE IVec<SSE,T> operator| (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_or_si128( a, b ); }
template <typename T>
FORCE_INLINE IVec<SSE,T> operator^ (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_xor_si128( a, b ); }
template <typename T>
FORCE_INLINE IVec<SSE,T> operator+ (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_add_epi32( a, b ); }
template <typename T>
FORCE_INLINE IVec<SSE,T> operator- (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_sub_epi32( a, b ); }
#ifdef USE_SSE41
template <typename T>
FORCE_INLINE IVec<SSE,T> min (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_min_epi32( a, b ); }
#endif
typedef VecStorage<InstrFloatTraits<SSE,float>> FVec128Float;
template <>
struct FVec1<SSE, float> : FVec128Float
{
FORCE_INLINE FVec1() {}
FORCE_INLINE FVec1( float f ) : FVec128Float( _mm_load_ss( &f ) ) {}
FORCE_INLINE FVec1( const vec_t& v ): FVec128Float( v ) {}
FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); }
};
template <>
struct FVec<SSE, float> : FVec128Float
{
FORCE_INLINE FVec() {}
FORCE_INLINE FVec( float f ) : FVec128Float( _mm_set1_ps( f ) ) {}
FORCE_INLINE FVec( const float *v ) : FVec128Float( _mm_loadu_ps( v ) ) {}
FORCE_INLINE FVec( const vec_t& v) : FVec128Float(v) {}
FORCE_INLINE FVec( float f3, float f2, float f1, float f0 ) : FVec128Float( _mm_set_ps(f3, f2, f1, f0) ) {}
void set0( float f ) { vec = _mm_load_ss( &f ); }
void setN( float f ) { vec = _mm_set1_ps( f ); }
FORCE_INLINE void setidx( const float *xi, const IVec<SSE,float>& idx )
{
uint32 i0 = idx.get0();
uint32 i1 = idx.get1();
uint32 i2 = idx.get2();
uint32 i3 = idx.get3();
vec = _mm_set_ps( xi[i3], xi[i2], xi[i1], xi[i0] );
}
FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); }
FORCE_INLINE float get1() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 1 ) ); }
FORCE_INLINE float get2() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 2 ) ); }
FORCE_INLINE float get3() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 3 ) ); }
};
FORCE_INLINE FVec1<SSE,float> operator+ (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_add_ss( a, b ); }
FORCE_INLINE FVec1<SSE,float> operator- (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_sub_ss( a, b ); }
FORCE_INLINE FVec1<SSE,float> operator* (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_mul_ss( a, b ); }
FORCE_INLINE FVec1<SSE,float> operator/ (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_div_ss( a, b ); }
FORCE_INLINE int ftoi (const FVec1<SSE,float>& a) { return _mm_cvttss_si32(a); }
FORCE_INLINE IVec<SSE,float> operator> (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_castps_si128( _mm_cmpgt_ss( a, b ) ); }
#ifdef USE_FMA
FORCE_INLINE FVec1<SSE, float> mulSub(const FVec1<SSE, float>& a, const FVec1<SSE, float>& b, const FVec1<SSE, float>& c) { return _mm_fmsub_ss(a, b, c); }
#endif
FORCE_INLINE FVec<SSE,float> operator- (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_sub_ps( a, b ); }
FORCE_INLINE FVec<SSE,float> operator* (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_mul_ps( a, b ); }
FORCE_INLINE FVec<SSE,float> operator/ (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_div_ps( a, b ); }
FORCE_INLINE IVec<SSE,float> ftoi (const FVec<SSE,float>& a) { return _mm_cvttps_epi32(a); }
FORCE_INLINE IVec<SSE,float> operator<= (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); }
FORCE_INLINE IVec<SSE,float> operator>= (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); }
FORCE_INLINE IVec<SSE,float> operator< (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); }
#ifdef USE_FMA
FORCE_INLINE FVec<SSE, float> mulSub(const FVec<SSE, float>& a, const FVec<SSE, float>& b, const FVec<SSE, float>& c) { return _mm_fmsub_ps(a, b, c); }
#endif
typedef VecStorage<InstrFloatTraits<SSE,double>> FVec128Double;
template <>
struct FVec1<SSE, double> : FVec128Double
{
FORCE_INLINE FVec1() {}
FORCE_INLINE FVec1( double f ) : FVec128Double( _mm_load_sd( &f ) ) {}
FORCE_INLINE FVec1( const vec_t& v ) : FVec128Double( v ) {}
FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); }
};
template <>
struct FVec<SSE, double> : FVec128Double
{
FORCE_INLINE FVec() {}
FORCE_INLINE FVec( double d ) : FVec128Double( _mm_set1_pd( d ) ) {}
FORCE_INLINE FVec( const double *v ) : FVec128Double( _mm_loadu_pd( v ) ) {}
FORCE_INLINE FVec( const vec_t& v) : FVec128Double( v ) {}
FORCE_INLINE FVec( double f1, double f0 ) : FVec128Double( _mm_set_pd(f1, f0) ) {}
void set0( double f ) { vec = _mm_load_sd( &f ); }
void setN( double f ) { vec = _mm_set1_pd( f ); }
FORCE_INLINE void setidx( const double *xi, const IVec<SSE,double>& idx )
{
vec = _mm_set_pd( xi[idx.get1()], xi[idx.get0()] );
}
FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); }
FORCE_INLINE double get1() const { return _mm_cvtsd_f64( _mm_shuffle_pd( vec, vec, 1 ) ); };
};
FORCE_INLINE FVec1<SSE,double> operator+ (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_add_sd( a, b ); }
FORCE_INLINE FVec1<SSE,double> operator- (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_sub_sd( a, b ); }
FORCE_INLINE FVec1<SSE,double> operator* (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_mul_sd( a, b ); }
FORCE_INLINE FVec1<SSE,double> operator/ (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_div_sd( a, b ); }
FORCE_INLINE int ftoi (const FVec1<SSE,double>& a) { return _mm_cvttsd_si32(a); }
FORCE_INLINE IVec<SSE,double> operator> (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_castpd_si128( _mm_cmpgt_sd( a, b ) ); }
#ifdef USE_FMA
FORCE_INLINE FVec1<SSE, double> mulSub(const FVec1<SSE, double>& a, const FVec1<SSE, double>& b, const FVec1<SSE, double>& c) { return _mm_fmsub_sd(a, b, c); }
#endif
FORCE_INLINE FVec<SSE,double> operator- (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_sub_pd( a, b ); }
FORCE_INLINE FVec<SSE,double> operator* (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_mul_pd( a, b ); }
FORCE_INLINE FVec<SSE,double> operator/ (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_div_pd( a, b ); }
FORCE_INLINE IVec<SSE,float> ftoi (const FVec<SSE,double>& a) { return _mm_cvttpd_epi32(a); }
FORCE_INLINE IVec<SSE,double> operator<= (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); }
FORCE_INLINE IVec<SSE,double> operator< (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); }
FORCE_INLINE IVec<SSE,double> operator>= (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); }
#ifdef USE_FMA
FORCE_INLINE FVec<SSE, double> mulSub(const FVec<SSE, double>& a, const FVec<SSE, double>& b, const FVec<SSE, double>& c ) { return _mm_fmsub_pd(a, b, c); }
#endif
#ifdef USE_AVX
template <>
struct IVecBase<AVX> : VecStorage<InstrIntTraits<AVX>>
{
protected:
FORCE_INLINE IVecBase() {}
FORCE_INLINE IVecBase( const vec_t& v) : VecStorage<InstrIntTraits<AVX>>( v ) {}
public:
FORCE_INLINE static vec_t zero() { return _mm256_setzero_si256(); }
FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32(_mm256_castsi256_si128(vec)); }
FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) { vec = _mm256_blendv_epi8(vec, val, mask); }
FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask)
{
vec = _mm256_blendv_epi8(vec, val, mask);
//vec = _mm256_or_si256(vec, _mm256_and_si256(val,mask));
}
FORCE_INLINE __m128i lo128() const { return _mm256_castsi256_si128(vec); }
FORCE_INLINE __m128i hi128() const { return _mm256_extractf128_si256(vec, 1); }
};
template <>
struct IVec<AVX, float> : IVecBase<AVX>
{
FORCE_INLINE IVec() {}
FORCE_INLINE IVec( int32 i ) : IVecBase<AVX>( _mm256_set1_epi32( i ) ) {}
FORCE_INLINE IVec( const vec_t& v) : IVecBase<AVX>( v ) {}
FORCE_INLINE IVec(uint32 u7, uint32 u6, uint32 u5, uint32 u4, uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase<AVX>(_mm256_set_epi32(u7, u6, u5, u4, u3, u2, u1, u0)) {}
void setN( int32 i ) { vec = _mm256_set1_epi32( i ); }
FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 1); }
FORCE_INLINE int32 get2() const { return _mm256_extract_epi32(vec, 2); }
FORCE_INLINE int32 get3() const { return _mm256_extract_epi32(vec, 3); }
FORCE_INLINE int32 get4() const { return _mm256_extract_epi32(vec, 4); }
FORCE_INLINE int32 get5() const { return _mm256_extract_epi32(vec, 5); }
FORCE_INLINE int32 get6() const { return _mm256_extract_epi32(vec, 6); }
FORCE_INLINE int32 get7() const { return _mm256_extract_epi32(vec, 7); }
FORCE_INLINE void setidx( const uint32 *bi, const IVec<AVX,float>& idx )
{
vec = _mm256_i32gather_epi32(reinterpret_cast<const int32 *>(bi), idx, sizeof(uint32));
}
FORCE_INLINE void store( uint32 *pi ) const { _mm256_storeu_si256( reinterpret_cast<vec_t*>(pi), vec ); }
FORCE_INLINE int countbit()
{
return popcnt32(_mm256_movemask_ps(_mm256_castsi256_ps(vec)));
}
};
template <>
struct IVec<AVX, double> : IVecBase<AVX>
{
FORCE_INLINE IVec() {}
FORCE_INLINE IVec( int32 i ) : IVecBase<AVX>( _mm256_set1_epi64x( i ) ) {}
FORCE_INLINE IVec( const vec_t& v) : IVecBase<AVX>( v ) {}
FORCE_INLINE IVec(uint64 u3, uint64 u2, uint64 u1, uint64 u0) : IVecBase<AVX>(_mm256_set_epi64x(u3, u2, u1, u0)) {}
void setN( int32 i ) { vec = _mm256_set1_epi64x( i ); }
// extract the 4 32 bits integers no. 0, 2, 4, 6 and store them in a __m128i
FORCE_INLINE IVec<SSE,float> extractLo32s() const
{
union {
uint32 u32[4];
__m128i u;
} mask = {0,2,4,6};
//__m256 ps256 = _mm256_castsi256_ps(vec);
//__m128 lo128 = _mm256_castps256_ps128(ps256);
//__m128 hi128 = _mm256_extractf128_ps(ps256, 1);
//__m128 blend = _mm_shuffle_ps(lo128, hi128, 0 + (2<<2) + (0<<4) + (2<<6));
__m256i blend = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(mask.u));
return _mm256_castsi256_si128(blend);
}
//int32 get1() const { return _mm256_cvtsi256_si32( _mm256_shuffle_epi32( vec, 2 ) ); };
FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 2); }
FORCE_INLINE void store( uint32 *pi ) const
{
extractLo32s().store(pi);
}
FORCE_INLINE int countbit()
{
return popcnt32(_mm256_movemask_pd(_mm256_castsi256_pd(vec)));
}
};
template <typename T>
FORCE_INLINE IVec<AVX,T> operator>> (const IVec<AVX,T>& a, unsigned n) { return _mm256_srli_epi32(a, n); }
template <typename T>
FORCE_INLINE IVec<AVX,T> operator<< (const IVec<AVX,T>& a, unsigned n) { return _mm256_slli_epi32(a, n); }
template <typename T>
FORCE_INLINE IVec<AVX,T> operator& (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_and_si256( a, b ); }
template <typename T>
FORCE_INLINE IVec<AVX,T> operator| (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_or_si256( a, b ); }
template <typename T>
FORCE_INLINE IVec<AVX,T> operator^ (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_xor_si256( a, b ); }
template <typename T>
FORCE_INLINE IVec<AVX,T> min (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_min_epi32( a, b ); }
FORCE_INLINE IVec<AVX,float> operator+ (const IVec<AVX,float>& a, const IVec<AVX,float>& b ) { return _mm256_add_epi32( a, b ); }
FORCE_INLINE IVec<AVX,float> operator- (const IVec<AVX,float>& a, const IVec<AVX,float>& b ) { return _mm256_sub_epi32( a, b ); }
FORCE_INLINE IVec<AVX,double> operator+ (const IVec<AVX,double>& a, const IVec<AVX,double>& b ) { return _mm256_add_epi64( a, b ); }
FORCE_INLINE IVec<AVX,double> operator- (const IVec<AVX,double>& a, const IVec<AVX,double>& b ) { return _mm256_sub_epi64( a, b ); }
typedef VecStorage<InstrFloatTraits<AVX,float>> FVec256Float;
template <>
struct FVec<AVX, float> : FVec256Float
{
FORCE_INLINE FVec() {}
FORCE_INLINE FVec( float f ) : FVec256Float( _mm256_set1_ps( f ) ) {}
FORCE_INLINE FVec( const float *v ) : FVec256Float( _mm256_loadu_ps( v ) ) {}
FORCE_INLINE FVec( const vec_t& v) : FVec256Float(v) {}
FORCE_INLINE FVec(float f7, float f6, float f5, float f4, float f3, float f2, float f1, float f0) : FVec256Float(_mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0)) {}
//void set0( float f ) { vec = _mm256_load_ss( &f ); }
void setN( float f ) { vec = _mm256_set1_ps( f ); }
FORCE_INLINE void setidx( const float *xi, const IVec<AVX,float>& idx )
{
#if 1 // use gather primitives
vec = _mm256_i32gather_ps (xi, idx, 4);
#elif 0
uint32 i0 = idx.get0();
uint32 i1 = idx.get1();
uint32 i2 = idx.get2();
uint32 i3 = idx.get3();
uint32 i4 = idx.get4();
uint32 i5 = idx.get5();
uint32 i6 = idx.get6();
uint32 i7 = idx.get7();
vec = _mm256_set_ps( xi[i7], xi[i6], xi[i5], xi[i4], xi[i3], xi[i2], xi[i1], xi[i0] );
#else
union {
__m256i vec;
uint32 ui32[8];
} i;
i.vec = static_cast<const __m256i&>(idx);
vec = _mm256_set_ps(xi[i.ui32[7]], xi[i.ui32[6]], xi[i.ui32[5]], xi[i.ui32[4]], xi[i.ui32[3]], xi[i.ui32[2]], xi[i.ui32[1]], xi[i.ui32[0]]);
#endif
}
FORCE_INLINE FVec<SSE, float> lo128() const { return _mm256_castps256_ps128(vec); }
FORCE_INLINE FVec<SSE, float> hi128() const { return _mm256_extractf128_ps(vec, 1); }
//FORCE_INLINE float get0() const { return _mm256_cvtss_f32( vec ); }
//FORCE_INLINE float get1() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 1 ) ); }
//FORCE_INLINE float get2() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 2 ) ); }
//FORCE_INLINE float get3() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 3 ) ); }
};
FORCE_INLINE FVec<AVX,float> operator- (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_sub_ps( a, b ); }
FORCE_INLINE FVec<AVX,float> operator* (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_mul_ps( a, b ); }
FORCE_INLINE FVec<AVX,float> operator/ (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_div_ps( a, b ); }
FORCE_INLINE IVec<AVX,float> ftoi (const FVec<AVX,float>& a) { return _mm256_cvttps_epi32(a); }
FORCE_INLINE IVec<AVX,float> operator<= (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_LE_OS) ); }
FORCE_INLINE IVec<AVX,float> operator>= (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_GE_OS ) ); }
FORCE_INLINE IVec<AVX,float> operator< (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256(_mm256_cmp_ps(a, b, _CMP_LT_OS )); }
#ifdef USE_FMA
FORCE_INLINE FVec<AVX, float> mulSub(const FVec<AVX, float>& a, const FVec<AVX, float>& b, const FVec<AVX, float>& c) { return _mm256_fmsub_ps(a, b, c); }
#endif
typedef VecStorage<InstrFloatTraits<AVX,double>> FVec256Double;
template <>
struct FVec<AVX, double> : FVec256Double
{
FORCE_INLINE FVec() {}
FORCE_INLINE FVec( double d ) : FVec256Double( _mm256_set1_pd( d ) ) {}
FORCE_INLINE FVec( const double *v ) : FVec256Double( _mm256_loadu_pd( v ) ) {}
FORCE_INLINE FVec( const vec_t& v) : FVec256Double( v ) {}
FORCE_INLINE FVec(double d3, double d2, double d1, double d0) : FVec256Double(_mm256_set_pd(d3, d2, d1, d0)) {}
//void set0( double f ) { vec = _mm256_load_sd( &f ); }
void setN( double f ) { vec = _mm256_set1_pd( f ); }
FORCE_INLINE void setidx( const double *xi, const IVec<SSE,float>& idx )
{
vec = _mm256_i32gather_pd(xi, idx, 8);
}
FORCE_INLINE void setidx( const double *xi, const IVec<AVX,double>& idx )
{
vec = _mm256_i64gather_pd(xi, idx, 8);
}
// FORCE_INLINE double get0() const { return _mm256_cvtsd_f64( vec ); }
// FORCE_INLINE double get1() const { return _mm256_cvtsd_f64( _mm256_shuffle_pd( vec, vec, 1 ) ); };
};
FORCE_INLINE FVec<AVX,double> operator- (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_sub_pd( a, b ); }
FORCE_INLINE FVec<AVX,double> operator* (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_mul_pd( a, b ); }
FORCE_INLINE FVec<AVX,double> operator/ (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_div_pd( a, b ); }
FORCE_INLINE IVec<SSE,float> ftoi (const FVec<AVX,double>& a) { return _mm256_cvttpd_epi32(a); }
FORCE_INLINE IVec<AVX,double> operator<= (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_LE_OS ) ); }
FORCE_INLINE IVec<AVX,double> operator< (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd(a, b, _CMP_LT_OS)); }
FORCE_INLINE IVec<AVX,double> operator>= (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_GE_OS ) ); }
#ifdef USE_FMA
FORCE_INLINE FVec<AVX, double> mulSub(const FVec<AVX, double>& a, const FVec<AVX, double>& b, const FVec<AVX, double>& c) { return _mm256_fmsub_pd(a, b, c); }
#endif
#endif
} // namespace Details
} // namespace BinSearch
#pragma once
#include <stddef.h>
#include <vector>
#include <limits>
#include "Portable.h"
using std::size_t;
namespace BinSearch {
enum InstrSet { Scalar, SSE, AVX };
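// X-macro trick: AlgoXCodes.h lists ALGOENUM(name, id) entries; here the
// macro expands each entry to just its name, producing the Algos enumerators.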
#define ALGOENUM(x, b) x,
enum Algos
{
#include "AlgoXCodes.h"
};
#undef ALGOENUM
namespace Details {
template <InstrSet I>
struct InstrIntTraits;
template <InstrSet I, typename T>
struct InstrFloatTraits;
// base class for algorithm supporting the method:
// uint32 scalar(T z) const
template <typename T, Algos A, typename Enable=void>
struct AlgoScalarBase;
// base class for algorithm supporting the following methods, constants and definitions:
// static const uint32 nElem
// struct Constants;
// void initConstants(Constants& cst) const
// void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
// The function vectorial processes nElem items
template <InstrSet I, typename T, Algos A, typename Enable=void>
struct AlgoVecBase;
template <typename T> struct IntTraits;
template <> struct IntTraits<float>
{
typedef uint32 itype;
};
template <> struct IntTraits<double>
{
typedef uint64 itype;
};
template <int N>
struct Body
{
template <uint32 D, typename T, typename Expr>
FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst)
{
e.vectorial(ri, zi, cst);
Body<N - 1>::template iteration<D>(e, ri + D, zi + D, cst);
}
};
template <>
struct Body<0>
{
template <uint32 D, typename T, typename Expr, typename H>
FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&)
{
}
};
template <typename T, typename Algo>
struct Loop
{
typedef Algo algo_type;
static const uint32 M = 4;
static const uint32 D = algo_type::nElem;
FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n)
{
typename algo_type::Constants cst;
e.initConstants(cst);
uint32 j = 0;
while (j + (D*M) <= n) {
Details::Body<M>::template iteration<D>(e, ri + j, zi + j, cst);
j += (D*M);
}
while (j + D <= n) {
e.vectorial(ri + j, zi + j, cst);
j += D;
}
while (D > 1 && j < n) {
ri[j] = e.scalar(zi[j]);
j += 1;
}
}
};
template <uint32 nIterTot, uint32 nIterLeft>
struct _Pipeliner
{
template <typename Expr, typename Data>
FORCE_INLINE static void go(const Expr& e, Data* d)
{
e.template run<nIterTot - nIterLeft>(d);
_Pipeliner<nIterTot, nIterLeft - 1>::go(e, d);
}
};
template <uint32 nIterTot>
struct _Pipeliner<nIterTot, 0>
{
template <typename Expr, typename Data>
FORCE_INLINE static void go(const Expr& e, Data* d)
{
}
};
template <uint32 nIter>
struct Pipeliner
{
template <typename Expr, typename Data>
FORCE_INLINE static void go(const Expr& e, Data* d)
{
_Pipeliner<nIter, nIter>::go(e, d);
}
};
#if 1
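// SFINAE completeness test: sizeof(T) is only valid for complete types, so the
// char-returning overload participates only when T is complete; otherwise the
// variadic long overload is picked and IsComplete<T>::value is false.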
template <class T>
char is_complete_impl(char (*)[sizeof(T)]);
template <class>
long is_complete_impl(...);
template <class T>
struct IsComplete
{
static const bool value = sizeof(is_complete_impl<T>(0)) == sizeof(char);
};
#else
template <class T, std::size_t = sizeof(T)>
std::true_type is_complete_impl(T *);
std::false_type is_complete_impl(...);
template <class T>
struct IsComplete : decltype(is_complete_impl(std::declval<T*>())) {};
#endif
template <typename T, Algos A>
struct AlgoScalarToVec : AlgoScalarBase<T,A>
{
typedef AlgoScalarBase<T, A> base_t;
AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {}
AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {}
static const uint32 nElem = 1;
struct Constants
{
};
void initConstants(Constants& cst) const
{
}
FORCE_INLINE
void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
{
*pr = base_t::scalar(*pz);
}
};
template<bool B, class T, class F>
struct conditional { typedef T type; };
template<class T, class F>
struct conditional<false, T, F> { typedef F type; };
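// CondData<T, true> stores a real T; CondData<T, false> stores nothing and
// reads back as 0, so the member costs nothing when the bound check is off.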
template <typename T, bool C>
struct CondData
{
FORCE_INLINE CondData(T x) : v(x) {}
FORCE_INLINE operator const T&() const { return v;}
private:
T v;
};
template <typename T>
struct CondData<T,false>
{
FORCE_INLINE CondData(T) {}
FORCE_INLINE operator const T() const { return 0;}
};
template <InstrSet I, typename T, Algos A, bool L=false>
struct BinAlgoBase : Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
, Details::AlgoVecBase<I, T, A>
, Details::AlgoScalarToVec<T,A>
>::type
{
typedef typename Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
, Details::AlgoVecBase<I, T, A>
, Details::AlgoScalarToVec<T,A>
>::type base_t;
BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {}
BinAlgoBase(const typename base_t::Data& d) : base_t(d) {}
};
} // namespace Details
} // namespace BinSearch
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os
from setuptools import find_packages, setup
libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
libs = [os.path.basename(p) for p in libs]
print("libs:", libs)
def read(fname):
    with open(os.path.join(os.path.dirname(__file__), fname)) as f:
        return f.read()
setup(
name=f"bitsandbytes",
version=f"0.35.4",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.",
license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression",
url="https://github.com/TimDettmers/bitsandbytes",
packages=find_packages(),
entry_points={
"console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
},
package_data={"": libs},
long_description=read("README.md"),
long_description_content_type="text/markdown",
classifiers=[
"Development Status :: 4 - Beta",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
from itertools import product, permutations
import pytest
import torch
import bitsandbytes as bnb
n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
product(
dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
)
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
*vals
)
for vals in str_values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
values,
ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if not torch.cuda.is_available(): pytest.skip('No GPU found.')
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
elif not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
elif transpose[0] and not transpose[1]:
out_torch = funcs[0](A.t(), B)
out_bnb = funcs[1](A.t(), B)
elif transpose[0] and transpose[1]:
out_torch = funcs[0](A.t(), B.t())
out_bnb = funcs[1](A.t(), B.t())
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_allclose(
gradB1, gradB2, atol=0.18, rtol=0.3
)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
B = torch.randn(
size=(dim1, dim3, dim4),
device="cuda",
requires_grad=req_grad[1],
)
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_allclose(
out_bnb, out_torch, atol=0.027, rtol=0.2
)
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
else:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
product(
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
decomp,
has_fp16_weights,
has_bias
)
)
str_values = list(
product(
dim1,
dim2,
dim3,
dim4,
str_funcs,
dtype,
req_grad_str,
str_transpose,
decomp,
has_fp16_weights,
has_bias
)
)
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
values,
ids=names,
)
def test_matmullt(
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
decomp,
has_fp16_weights,
has_bias
):
if not torch.cuda.is_available(): pytest.skip('No GPU found.')
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
    if not has_bias:
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(
size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
)
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
B = torch.randn(
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
target = torch.randn(
size=(dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
dtype=dtype,
)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
state = bnb.MatmulLtState()
state.threshold = decomp
state.has_fp16_weights = has_fp16_weights
if not has_fp16_weights:
if not transpose[0] and not transpose[1]:
B2 = B2.t().contiguous()
(
state.CB,
CBt,
state.SCB,
SCBt,
coo_tensorB,
) = bnb.functional.double_quant(B2.to(torch.float16))
B2 = state.CB
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B2, state=state, bias=bias2)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)
if has_bias:
out_torch += bias
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).mean().item()
# print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() <= n * 0.001
if has_fp16_weights:
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(
out_bnb, target
).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
if has_bias:
gradBias1 = bias.grad
bias.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if has_bias:
gradBias2 = bias.grad
bias.grad = None
if req_grad[0]:
torch.testing.assert_allclose(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
torch.testing.assert_allclose(
gradB1, gradB2, atol=0.18, rtol=0.3
)
if req_grad[2]:
torch.testing.assert_allclose(gradBias1, gradBias2)
import os
from pathlib import Path
from typing import List, NamedTuple

import pytest

import bitsandbytes as bnb
from bitsandbytes.cuda_setup import (
CUDA_RUNTIME_LIB,
evaluate_cuda_setup,
determine_cuda_runtime_lib_path,
extract_candidate_paths,
)
"""
'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/'
'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda'
'LESSCLOSE': '/usr/bin/lesspipe %s %s'
'OLDPWD': '/mnt/D/titus/src'
'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit'
'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock'
'CONDA_PREFIX_1': '/mnt/D/titus/miniconda'
'PWD': '/mnt/D/titus/src/8-bit'
'HOME': '/mnt/D/titus'
'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python'
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
'TMUX': '/tmp/tmux-1007/default,59286,1'
'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop'
'SSH_TTY': '/dev/pts/0'
'MAIL': '/var/mail/titus'
'SHELL': '/bin/bash'
'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus'
'XDG_RUNTIME_DIR': '/run/user/1007'
'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin'
'LESSOPEN': '| /usr/bin/lesspipe %s'
'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python'
# any that include 'CONDA' that are not 'CONDA_PREFIX'
# we search for
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
"""
class InputAndExpectedOutput(NamedTuple):
input: str
output: str
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
(
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
]
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
    # Sketch of the intended setup (assumption): materialize each candidate
    # under tmpdir, creating the runtime library as a file and the remaining
    # candidates as directories, so the detector has real paths to inspect.
    input_path, _expected = request.param
    for path in extract_candidate_paths(input_path):
        target = Path(tmpdir) / path
        if CUDA_RUNTIME_LIB in path:
            target.parent.mkdir(parents=True, exist_ok=True)
            target.touch()
        else:
            target.mkdir(parents=True, exist_ok=True)
    return request.param
UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
]
def test_full_system():
## this only tests the cuda version and not compute capability
# if CONDA_PREFIX exists, it has priority before all other env variables
    # but it does not contain the library directly, so we need to look in a sub-folder
version = ""
if "CONDA_PREFIX" in os.environ:
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
version = float(f"{major}.{minor}")
if version == "" and "LD_LIBRARY_PATH" in os.environ:
ld_path = os.environ["LD_LIBRARY_PATH"]
paths = ld_path.split(":")
version = ""
for p in paths:
if "cuda" in p:
idx = p.rfind("cuda-")
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
version = float(version)
break
    assert version != "" and version > 0
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
binary_name = binary_name.replace("libbitsandbytes_cuda", "")
assert binary_name.startswith(str(version).replace(".", ""))
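# Worked example of the path parsing above (illustrative only; the underscore
# prefix keeps pytest from collecting it): the slice p[idx + 5 : idx + 5 + 4]
# grabs the four characters after 'cuda-'.
def _demo_parse_cuda_version():
    p = "/mnt/D/titus/local/cuda-11.1/lib64"
    idx = p.rfind("cuda-")
    assert float(p[idx + 5 : idx + 5 + 4].replace("/", "")) == 11.1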
import math
import random
import time
from itertools import product
import einops
import pytest
import torch
import numpy as np
import bitsandbytes as bnb
from bitsandbytes import functional as F
from scipy.stats import norm
torch.set_printoptions(
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20
def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
idx = torch.isclose(a, b, rtol, atol)
sumval = (idx == 0).sum().item()
if sumval > count:
print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
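# Usage sketch for the helper above (illustrative, CPU-only): up to `count`
# elements may violate the tolerance before the strict elementwise assert is
# triggered.
def _demo_assert_all_approx_close():
    a = torch.zeros(100)
    b = a.clone()
    b[0] = 1.0  # a single outlier is tolerated with count=1
    assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=1)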
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
super(FFN, self).__init__()
self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
with torch.no_grad():
torch.nn.init.xavier_uniform_(self.fc1.weight)
torch.nn.init.xavier_uniform_(self.fc2.weight)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
class Timer(object):
def __init__(self):
self.starts = {}
self.ends = {}
self.agg = {}
def tick(self, name="default"):
if name not in self.starts:
self.starts[name] = torch.cuda.Event(enable_timing=True)
self.ends[name] = torch.cuda.Event(enable_timing=True)
self.starts[name].record()
else:
ms = self.tock(name, evict=True, print_ms=False)
def tock(self, name="default", evict=True, print_ms=True):
if name in self.ends:
self.ends[name].record()
torch.cuda.synchronize()
ms = self.starts[name].elapsed_time(self.ends[name])
if name not in self.agg:
self.agg[name] = 0.0
self.agg[name] += ms
if evict:
self.starts.pop(name)
self.ends.pop(name)
if print_ms and name in self.agg:
print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
return self.agg[name]
def reset(self):
self.starts = {}
self.ends = {}
self.agg = {}
print("Resetting benchmark data")
def setup():
pass
def teardown():
pass
@pytest.mark.parametrize(
"dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
def test_dynamic_blockwise_quantization():
#print('')
for blocksize in [4096, 2048, 1024, 512]:
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.011
assert relerr < 0.018
#print('randn', blocksize, sum(diffs)/len(diffs))
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.0035
assert relerr < 0.015
#print('rand', blocksize, sum(diffs)/len(diffs))
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
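# Scale-storage arithmetic for the test above (worked example): a 1024x1024
# tensor holds 1,048,576 values, so blocksize 4096 stores 256 absmax scales
# while blocksize 512 stores 2048; smaller blocks track local magnitudes more
# closely, which is why the error bounds hold across all block sizes.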
def test_dynamic_blockwise_stochastic_quantization():
diffs = []
reldiffs = []
rand = torch.rand(1024).cuda()
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C1, S1 = F.quantize_blockwise(A1, rand=rand)
C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between quantized values
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
fraction_smaller = (C1 < C2).float().sum() / C1.numel()
fraction_larger = (C1 > C2).float().sum() / C1.numel()
torch.testing.assert_allclose(
fraction_larger, fraction_smaller, atol=0.01, rtol=0
)
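# Intuition sketch for stochastic rounding (self-contained and illustrative;
# not a claim about the kernel's exact mechanism): a value sitting fraction p
# of the way between two codes rounds up with probability ~p, so rounding is
# unbiased in expectation while any single draw may differ by one code.
def _demo_stochastic_rounding():
    torch.manual_seed(0)
    frac = torch.full((100000,), 0.3)
    rounded_up = (torch.rand_like(frac) < frac).float().mean().item()
    assert abs(rounded_up - 0.3) < 0.01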
@pytest.mark.parametrize(
"gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
percentile = 5
for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping(
g, gnorm_vec2, step, percentile=percentile
)
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)
gnorm2 = torch.norm(g.float())
if step == 1:
gnorm_vec1[:] = gnorm2
else:
gnorm_vec1[step % 100] = gnorm2
vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile]
torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
torch.testing.assert_allclose(clip1, clip2)
torch.testing.assert_allclose(gnorm1, gnorm2)
def quant(x):
max1 = torch.abs(x).max()
x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
def dequant(c, maxC):
return c.float() * (maxC / 127)
def mm_dequant(maxA, maxB, C):
return C.float() * (maxA / 127) * (maxB / 127)
def quant_multi(x, dim):
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
def quant_multi_chunk(x, dim, chunk_size=32):
if dim == 1:
x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
max1 = torch.tile(max1, (1, 1, x.shape[1]))
max1 = max1.view(x.shape)
elif dim == 0:
x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
max1 = torch.tile(max1, (x.shape[0], 1, 1))
max1 = max1.view(x.shape)
max1[max1 == 0] = 1.0
x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
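# Illustrative comparison of the two schemes above (CPU-only sketch; the
# underscore prefix keeps pytest from collecting it): per-row scales track
# uneven magnitudes better than one global scale, so the vector-wise
# round-trip error is smaller when a single row dominates.
def _demo_linear_vs_vectorwise():
    x = torch.randn(64, 64)
    x[0] *= 10  # one large-magnitude row inflates the single global scale
    max_lin, q_lin = quant(x)
    max_vec, q_vec = quant_multi(x, dim=1)
    err_lin = (dequant(q_lin, max_lin) - x).abs().mean()
    err_vec = (dequant(q_vec, max_vec) - x).abs().mean()
    assert err_vec <= err_lin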
def quant_minmax(A):
    # asymmetric (min-max) variant: shift values into a symmetric range, then
    # scale by half the value range; mirrors the min_max helper defined inside
    # test_minmax_igemm below (completing this stub that way is an assumption).
    minA = A.min()
    maxA = A.max()
    scale = (maxA - minA) / 2.0
    return (127 * (A - minA - scale) / scale).to(torch.int8), minA, scale
def mean(xx):
return sum(xx) / float(len(xx))
# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
(
lambda x, dim: quant(x),
lambda x, dim: quant(x),
dequant,
dequant,
mm_dequant,
)
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
for vals in values_names
]
@pytest.mark.parametrize(
"dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
print("")
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1)
else:
A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_allclose(
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
)
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
else:
out2 = torch.mm(A, B)
C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C)
std = out2.std()
out /= std
out2 /= std
err = torch.abs(out - out2)
relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
print(mean(errors))
print(mean(relerrors))
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
layer.reset_parameters()
n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
"hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
for vals in values
]
@pytest.mark.parametrize(
"hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (
(batch_dim, hidden_dim)
if not transpose[0]
else (hidden_dim, batch_dim)
)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())
elif transpose[0] and not transpose[1]:
out2 = torch.matmul(A.t().float(), B.float())
out = F.igemm(A.t(), B)
elif transpose[0] and transpose[1]:
out2 = torch.matmul(A.t().float(), B.t().float())
out = F.igemm(A.t(), B.t())
torch.testing.assert_allclose(out.float(), out2)
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())
torch.testing.assert_allclose(out.float(), out2)
n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
"seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
for i in range(25):
A = torch.randint(
-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
).to(torch.int8)
B = torch.randint(
-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
).to(torch.int8)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
iout = torch.empty(
A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
)
out = F.igemm(A, B, out=iout)
torch.testing.assert_allclose(out.float(), out2)
n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
"seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
for vals in values
]
@pytest.mark.parametrize(
"seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
scale = (maxA - minA) / 2.0
return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
seq_dim = seq_dim - (seq_dim % 16)
hidden_dim = hidden_dim - (hidden_dim % 16)
batch_dim = batch_dim - (batch_dim % 2)
errs = []
relerrs = []
errs2 = []
relerrs2 = []
for i in range(k):
A = torch.normal(
0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
)
if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
Ac, minA, scale = min_max(A)
if transpose:
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
out = F.igemm(Ac, Bc.t())
out2 = torch.matmul(A, B.t())
offset = B.t().sum(0) * (minA + scale)
out = out.float()
out = (out * maxB.t() * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
else:
maxB, Bc = quant_multi(B, dim=0)
offset = B.sum(0) * (minA + scale)
out = F.igemm(Ac, Bc)
out2 = torch.matmul(A, B)
out = out.float()
out = (out * maxB * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc)
out3 = mm_dequant(maxA, maxB, out3)
std = out2.std()
out2 /= std
out /= std
out3 /= std
err = torch.abs(out - out2)
relerr = err / (torch.abs(out2) + 1e-7)
err2 = torch.abs(out3 - out2)
relerr2 = err2 / (torch.abs(out2) + 1e-7)
errs.append(err.mean().item())
relerrs.append(relerr.mean().item())
errs2.append(err2.mean().item())
relerrs2.append(relerr2.mean().item())
# print(mean(errs))
# print(mean(relerrs))
# print(mean(errs2))
# print(mean(relerrs2))
assert mean(errs) < 0.015
assert mean(relerrs) < 0.3
n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.bmm(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
out = F.igemm(A, B.permute([0, 2, 1]))
elif transpose[0] and not transpose[1]:
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
out2 = torch.bmm(
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
)
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float())
n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
for i in range(k):
A = torch.randn(size=(dim2, dim3), device="cuda")
qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))
n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
func = F.get_transform_func(dtype, orderA, orderOut, transpose)
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
dtype
)
out, S = F.nvidia_transform(A, to_order=orderOut)
if orderOut == "row":
torch.testing.assert_allclose(A.flatten(), out.flatten())
elif orderOut == "col":
torch.testing.assert_allclose(A.t().flatten(), out.flatten())
elif orderOut == "col32":
if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
n = (
A.shape[0]
* A.shape[1]
* (A.shape[2] + (32 - (A.shape[2] % 32)))
)
assert out.numel() == n
elif orderOut == "col_turing":
# 32 col 8 row tiles
n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
A.shape[1] + (32 - (A.shape[1] % 32))
)
assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]):
for col in range(A.shape[1]):
i = row * A.shape[1]
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
rowtile = (
(row // 8) + (1 if row % 8 != 0 else 0)
) * total_coltile
offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
row2 = (row % 8) * 32
assert A.flatten()[i + j] == A[row, col]
# assert A.flatten()[i+j] == out.flatten()[row2+col2]
# torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32":
out2, S = F.nvidia_transform(
out, from_order=orderOut, to_order="row", state=S
)
torch.testing.assert_allclose(A, out2)
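# col32 size arithmetic used above (worked example): a 3x40 tensor gets its
# columns padded up to the next multiple of 32, giving
# 3 * (40 + (32 - 40 % 32)) = 3 * 64 = 192 elements in the transformed buffer.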
n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]
dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
torch.int8
)
elif dims == 3:
A = torch.randint(
-128, 127, size=(dim1, dim2, dim3), device="cuda"
).to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32")
B2, SB = F.transform(B, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float())
# transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float())
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
formatB = F.get_special_format_str()
for i in range(k):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
A = torch.normal(
0, 0.5, size=(dim1, dim2, dim3), device="cuda"
).half()
B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
C2 = bnb.matmul(A, B.t())
A = A.view(-1, A.shape[-1])
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
C32A, SA = F.transform(CA, "col32")
CxB, SB = F.transform(CB, to_order=formatB)
out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
# print('')
# print(output.flatten()[:10])
# print(C1.flatten()[:10])
# print(C2.flatten()[:10])
# torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float())
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
# C2, SC = F.igemmlt(A2, B2t, SA, SBt)
# C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_allclose(C1, C3.float())
batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
(batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
(batch_size, seqdim, 5120, 3 * 5120),
(batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]
# values = list(product(batch, seq, model, hidden))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
grad = torch.randn(batch, seq, model, device="cuda").half()
w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
print("")
# torch.cuda.synchronize()
## warmup
# for i in range(100):
# torch.matmul(A, w1.t())
# torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
grad = grad.view(-1, grad.shape[-1]).contiguous()
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
# torch.cuda.empty_cache()
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# C32A, SA = F.transform2(CA, 'col32')
# # fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
# #print(coo_tensor.nnz)
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
# #print(w1.t().shape)
# #out1 = out1dn + out1sp
# # fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
# # delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
# # delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
# # grad1
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
# ## grad2
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()
#dim1 = [2*1024]
#dim4 = [2*1024]
#dim1 = [4]
#dim4 = [4]
dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
inner = torch.randint(1, 128, size=(1,)).item()
bias = None
if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
    formatB = F.get_special_format_str()  # use the device's native B format; overrides the parametrized value
for i in range(1):
A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
if has_bias: C1 += bias
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, formatB)
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
if has_bias: C4 += bias
# TODO: is something wrong here? If so, the problem goes deeper
#n = C1.numel()
#p = 0.06
std = C1.std(0).view(1, -1)
C1 /= std
C4 /= std
#assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
#assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
#torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))
n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
for i in range(k):
threshold = 3.0
A = torch.randn(dim1, dim2, device="cuda").half()
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
if dims == 2:
row_stats1, _ = torch.abs(A.float()).max(1)
col_stats1, _ = torch.abs(A.float()).max(0)
row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
else:
assert False
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=threshold
)
A_blocked = einops.rearrange(
torch.abs(A),
"(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
row_tiles=16,
block_size=64 * 4,
)
nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
nnz_block_ptr1 = torch.zeros(
nnz_rows1_counts.shape[0] + 1,
dtype=nnz_rows1_counts.dtype,
device=nnz_rows1_counts.device,
)
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=0.0
)
torch.testing.assert_allclose(col_stats1, col_stats2)
torch.testing.assert_allclose(row_stats1, row_stats2)
assert nnz_block_ptr2 is None
n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# max difference is 1 due to rounding differences
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (
(torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
)
num_not_close_cols = (
(torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
)
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Error threshold exceeded: {num_not_close_cols} column elements differ. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Error threshold exceeded: {num_not_close_rows} row elements differ. Error: {num_not_close_rows/n:.4f}"
            )
            assert False
torch.testing.assert_allclose(Srow.flatten(), statsA)
torch.testing.assert_allclose(Scol.flatten(), statsAt)
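# Why atol=1 above (worked example, illustrative): two kernels that treat
# half-way values differently disagree by exactly one code at .5 boundaries.
def _demo_off_by_one_rounding():
    x = torch.tensor([2.5, 3.5])
    half_to_even = torch.round(x)     # tensor([2., 4.])
    half_away = torch.floor(x + 0.5)  # tensor([3., 4.])
    assert (half_to_even - half_away).abs().max() <= 1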
n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
torch.testing.assert_allclose(maxA.flatten(), stats1a)
torch.testing.assert_allclose(maxB.flatten(), stats2a)
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(C2a, "col_turing")
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.025
n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
scale = maxval / 120
out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
out4 = C * SA * SB / (127 * 127)
# out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
# print('='*80)
# print(out1)
# print(out2)
# print(out3)
# print(out1)
# print(out2)
# print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
# assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print("")
print(sum(err1) / len(err1))
print(sum(err2) / len(err2))
print(sum(err3) / len(err3))
dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()  # needed by the nvidia_transform calls below
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
    # warmup
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
torch.cuda.synchronize()
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB)
torch.cuda.synchronize()
print("vector-wise", time.time() - t0)
n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]
dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
*vals
)
for vals in values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
values,
ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
dtype
)
elif dims == 3:
A = torch.randint(
10, 99, size=(dim1, dim2, dim3), device="cuda"
).to(dtype)
A.view(-1)[-1] = -1
if transpose:
At = A.t().contiguous()
out1, S1 = F.nvidia_transform(At, to_order=orderOut)
else:
out1, S1 = F.nvidia_transform(A, to_order=orderOut)
out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
# print(out1)
# print(out2)
torch.testing.assert_allclose(out1, out2)
n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]
dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
"dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
for vals in values
]
def test_overflow():
formatB = F.get_special_format_str()
print(formatB)
for i in range(2):
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
Ca, Sa = F.nvidia_transform(a, "col32")
Cb, Sb = F.nvidia_transform(b, formatB)
c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
c2 = torch.matmul(a.float(), b.float().t())
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
)
if coo_tensor is not None:
A1 = A * idx
A2 = torch.zeros_like(A)
A2[
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values
torch.testing.assert_allclose(A1, A2)
A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_allclose(
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
)
n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
# dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
B = torch.randn(dim3, dim2).cuda().half()
else:
B = torch.randn(dim2, dim3).cuda().half()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
out1 = torch.matmul(A2, B.t())
else:
out2 = F.spmm_coo(cooA, B)
out1 = torch.matmul(A2, B)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
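# COO layout used above (worked example): the matrix [[0, 5], [7, 0]] becomes
# rowidx=[0, 1], colidx=[1, 0], values=[5, 7], i.e. one (row, col, value)
# triple per non-zero entry, which is exactly what F.COOSparseTensor is fed.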
def test_spmm_bench():
batch = 2
model = 1024 * 1
hidden = model * 4
seq = 1024
dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B.t())
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = bnb.matmul(A, B.t())
torch.cuda.synchronize()
t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
for i in range(10):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
tsp = time.time() - t0
print(tsp, t8)
print(tsp / t8)
n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
formatB = "col_turing"
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
out1 = torch.matmul(A, w1.t())
Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
CTw1, Sw1 = F.transform(Cw1, formatB)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
)
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
assert coo_tensor is not None
out4 = F.spmm_coo(coo_tensor, w1.t())
out5 = out3 + out4
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
def test_matmuls():
a = torch.randn(256, 512).half().cuda()
b = torch.randn(256, 512).half().cuda()
c1 = torch.matmul(a, b.t())
c2 = bnb.matmul(a, b)
c3 = bnb.matmul_cublas(a, b.t())
err1 = torch.abs(c1 - c2).mean().item()
err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
print(err1, err2)
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
"dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
else:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
B, SB = F.vectorwise_quant(B, quant_type="linear")
# B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
print("")
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
# print(B)
# print(out1)
# print(out2)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
assert_all_approx_close(
out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
# torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
# torch.cuda.synchronize()
# t0 = time.time()
# print(A2.shape, B.shape)
# for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
# torch.cuda.synchronize()
# print(time.time() - t0)
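# COO -> CSR conversion: row-pointer differences must equal the per-row
# nonzero counts and the values must round-trip unchanged.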
def test_coo2csr():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_allclose(A2[idx], csrA.values)
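# COO -> CSC conversion: column-pointer differences must equal the
# per-column nonzero counts; values are compared in column-major order.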
def test_coo2csc():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
# torch tensors are row-major -> transpose to compare values in column-major order
idx = A2.t() != 0
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
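# spmm_coo_very_sparse with int8 B: fused dequantization via dequant_stats
# must match manual dequantization after the sparse matmul; the timing loops
# below compare cusparse fp16, int8, fused int8+dequant, dense matmul, and
# sparse + dense decomposition variants.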
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
A[:, rowidx] = 8.0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
max_count, max_idx = torch.sort(counts, descending=True)
print(torch.median(max_count.float()))
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
# torch.cuda.synchronize()
# print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = torch.matmul(A, B)
torch.cuda.synchronize()
print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1 + out2
torch.cuda.synchronize()
print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
print("partial matmul", time.time() - t0)
batch_size = 1
seqdim = 1
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
iters = 128
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit.eval()
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = (
bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
)
linearMixedBit.eval()
# warmup
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B, threshold=6.0)
torch.cuda.synchronize()
print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
C32A, SA = F.transform(CA, "col32")
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize()
print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1)
C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize()
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize()
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linearMixedBit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
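# Zero-point (asymmetric) quantization arithmetic: quant_zp maps x onto
# [-127, 127] with a scale and zero point, and C2..C7 verify the correction
# terms that recover A @ B from shifted and scaled operands.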
def test_zeropoint():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0:
dyna = 1
qx = 254.0 / dyna
minx = x.min()
# zpx = torch.round(minx* qx)
# zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min() * qx) - 127
x = (qx * x) + zpx
return x, qx, zpx
batch = 2
seq = 512
model = 1024
hidden = 4 * model
A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
# A, SA = F.vectorwise_quant(A, quant_type='linear')
# B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
C1 = torch.matmul(A, B)
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
# C2 = torch.matmul(A-zp, B)
# C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max())
print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
C5 = torch.matmul((A * scale) - zp, B)
C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C6 -= zpa * zpb * A.shape[1]
C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb
print("")
# print(C0.flatten()[:10])
print(C1.flatten()[:10])
print(C2.flatten()[:10])
print(C3.flatten()[:10])
print(C5.flatten()[:10])
print(C6.flatten()[:10])
print(C7.flatten()[:10])
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
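# extract_outliers must return the same columns from the col_turing and
# col_ampere tiled layouts as plain column indexing on the row-major tensor.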
def test_extract_outliers():
for i in range(k):
shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
# idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()]
CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2)
CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2)
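# Blockwise quantization round-trip on CPU for large block sizes; the mean
# absolute error must stay below 0.011.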
def test_blockwise_cpu_large():
diffs = []
reldiffs = []
batch = 128
seq = 128
for hidden in [128]:#, 14336]:
for blocksize in [4096, 16384]:
for i in range(2):
A1 = torch.randn(batch, seq, hidden, device='cpu')
t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
print(time.time() - t0)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
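# FP8 codebooks: sweep exponent/mantissa splits (1 sign + e_bits + p_bits =
# 8 bits) and record blockwise quantization error on normal and uniform data.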
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7-e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
print(e_bits, p_bits)
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr))
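# k-bit quantization maps (linear, fp8, dynamic, quantile): the codebook must
# contain 2**bits distinct values (2**bits - 1 when two code slots collapse
# to zero), and blockwise quantization is compared against a brute-force
# nearest-codeword lookup.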
def test_few_bit_quant():
#print('')
for bits in range(2, 9):
#print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
abserrs = []
relerrs = []
code = None
if method == 'linear':
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8':
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == 'dynamic':
code = F.create_dynamic_map(True, bits-0, bits).cuda()
elif method == 'quantile':
values = torch.randn(2048, 2048, device='cuda')
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
#print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device='cuda')
values /= values.abs().max()
#values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v-code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2-values)
abserrs.append(err2.mean().item())
relerrs.append((err2/(1e-10+values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
#assert err2.mean() <= err1
else:
torch.testing.assert_allclose(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
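# estimate_quantiles must approximate standard-normal quantiles (scipy
# norm.ppf reference) for both 2**bits and 2**bits - 1 quantile points.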
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 4):
total_values = 2**bits-1
p = np.linspace(0, 1, 2*total_values+1)
idx = np.arange(1, 2*total_values+1, 2)
p = p[idx]
offset = 1/(2*total_values)
p = np.linspace(offset, 1-offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
err = torch.abs(val1-val2).mean()
assert err < 0.035
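# Benchmark dequantize_blockwise; max_theoretical_mu below is the minimum
# time in microseconds to stream 2 MiB at 672 GB/s, i.e. a memory-bandwidth
# bound (the bandwidth figure is device-specific).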
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
qa, SA = F.quantize_blockwise(a)
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
F.dequantize_blockwise(qa, SA, blocksize=2048)
torch.cuda.synchronize()
#print((time.time()-t0)/1e6)
import math
from itertools import product
import einops
import pytest
import torch
from torch import nn
import bitsandbytes as bnb
# math and einops are required by the LinearFunction helpers below
class MockArgs(object):
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
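# Minimal two-layer Linear8bitLt MLP used by the module tests below.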
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
def get_args():
args = MockArgs([])
args.quant_type = "vector"
args.use_8bit_training = "full"
args.clip_freq = 9999
return args
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
idx = torch.isclose(a, b, rtol, atol)
sumval = (idx == 0).sum().item()
if sumval > count:
print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
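# Reference autograd.Function for 8-bit training experiments: linear,
# vector-wise, and min-max quant/dequant helpers, fake 8-bit weight storage,
# and an igemm-based forward/backward selected by args.use_8bit_training.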
class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm
std = torch.std(x)
max1 = std * trim_value
x = x / max1 * 127
x = round_func(x)
x[x > 127] = 127
x[x < -127] = -127
x = x / 127 * max1
return x
def quant(x, quant_type, dim=1):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float()
scale = (maxA - minA) / 2.0
xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float())
else:
return None
def dequant(xq, S1, S2, dtype, quant_type):
if quant_type == "linear":
norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows
return (xq.float() * norm).to(dtype)
elif quant_type == "vector":
x = xq.float()
if len(xq.shape) == 2 and len(S1.shape) == 3:
S1 = S1.squeeze(0)
if len(xq.shape) == 2 and len(S2.shape) == 3:
S2 = S2.squeeze(0)
# print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2:
x *= S1.t() / 127
else:
x *= S1 / 127
x *= S2 / 127
return x.to(dtype)
else:
return None
def dequant_min_max(xq, A, B, SA, SB, dtype):
offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3:
SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SA.shape) == 3:
SA = SA.squeeze(0)
if len(SB.shape) == 2:
x *= SB.t() / 127
else:
x *= SB / 127
x *= SA[1] / 127
x += offset
return x.to(dtype)
def get_8bit_linear(x, stochastic=False):
round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
max1 = torch.abs(x).max()
x = x / max1 * 127
x = round_func(x) / 127 * max1
# x = torch.round(x)/128*max1
return x
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = (x * 127) / max1
x = round_func(x) / 127 * max1
return x
@staticmethod
def round_stoachastic(x):
sign = torch.sign(x)
absx = torch.abs(x)
decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal)
return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
@staticmethod
def fake_8bit_storage(w, exponent_bits):
code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
# C = bnb.functional.quantize_no_absmax(code, w)
# out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
# print(out)
# out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_stoachstic(w):
rand = torch.rand(1024, device=w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
out = bnb.functional.dequantize_blockwise(absmax, C)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
mask = torch.zeros_like(blocked_w)
mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
mask = mask.bool()
# 1. zero out max values
# 2. quantize + dequantize
# 3. write back max values
# 4. copy matrix back to weight
values = blocked_w[mask]
blocked_w[mask] = 0
code = bnb.functional.create_dynamic_map()
code = code.to(w.device)
absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
blocked_w[mask] = values
unblocked_w = blocked_w.flatten().view(w.shape)
w.copy_(unblocked_w)
return unblocked_w
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(
outputq, S1, S2, x.dtype, args.quant_type
)
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
# relerr = err/(torch.abs(output32).float()+1e-8)
# print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
# output = torch.matmul(x, weight.t())
output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
x, weight, bias = ctx.saved_tensors
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=2
)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(
grad_input8, S1, S3, grad_output.dtype, args.quant_type
)
else:
grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None
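# nn.Module wrapper around LinearFunction; the quantization scheme is chosen
# via args.quant_type.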
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__()
self.input_features = input_features
self.output_features = output_features
self.args = args
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
torch.nn.init.zeros_(self.bias)
def forward(self, x):
self.args.training = self.training
return LinearFunction.apply(x, self.weight, self.bias, self.args)
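# Linear8bitLt inference: the quantized weight buffer state.CxB must be
# cached after the first forward pass.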
threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
assert l1.weight.device.type == "cuda"
assert l1.weight.dtype == torch.float16
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
if i == 1:
assert l1.state.CxB is not None
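# Linear8bitLt under gradient accumulation must track fp16 torch.nn.Linear
# layers initialized with identical weights.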
def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(
*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
)
l2 = torch.nn.Sequential(
*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
)
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
acc_steps = 10
for i in range(10):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
o2 = l2(b1)
loss1 = o1.mean()
loss2 = o2.mean()
loss1.backward()
loss2.backward()
if i == 2:
assert l1[0].state.CxB is not None
assert l1[1].state.CxB is not None
if i > 0 and i % acc_steps == 0:
opt1.step()
opt1.zero_grad(True)
opt2.step()
opt2.zero_grad(True)
assert_all_approx_close(
l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
)
assert_all_approx_close(
l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
)
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
else:
torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
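# With has_fp16_weights=False the weights must be int8 after any order of
# .cuda()/.half()/.to("cuda"), outliers must be tracked when threshold > 0,
# and the memory-efficient backward must match an fp16 reference gradient.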
threshold = [0.0, 2.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
bnb.nn.Linear8bitLt(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.cuda()
.half()
)
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
assert o1.dtype == torch.float16
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.half()
.cuda()
)
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
mlp = (
MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.half()
.to("cuda")
)
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
mlp = MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
if memory_efficient_backward:
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
o1 = mlp(b1)
assert o1.dtype == torch.float16
assert o1.requires_grad
grad_proj = torch.randn_like(o1)
mlp.zero_grad()
(o1 * grad_proj).sum().backward()
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean()
torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
assert (idx == 0).sum().item() <= b1.numel() * 0.005
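# A float32 bias must be cast to fp16 on the first forward pass through
# Linear8bitLt.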
def test_linear8bitlt_fp32_bias():
# casts model to fp16 -> int8 automatically
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
assert l1.weight.dtype == torch.int8
assert l1.bias.dtype == torch.float32
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
# casts bias to fp32
o1 = l1(b1)
assert l1.bias.dtype == torch.float16
# casts model to fp16 -> int8 automatically
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
assert l1.weight.dtype == torch.int8
assert l1.bias is None
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
assert l1.bias is None
import ctypes
import os
import shutil
import time
import uuid
from itertools import product
from os.path import join
import pytest
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
# import apex
k = 20
def get_temp_dir():
path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
os.makedirs(path, exist_ok=True)
return path
def rm_path(path):
shutil.rmtree(path)
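# Maps test ids to (reference optimizer, bnb optimizer) constructor pairs;
# some benchmark-only entries keep an extra leading None placeholder.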
str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["momentum_pytorch"] = (
None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["lars"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
)
str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["lars8bit"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers["adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
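# Maps each optimizer to (torch state key, bnb state key) pairs; the 8-bit
# variants also list their quantization-map and (abs)max buffer names.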
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lamb8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [
("square_avg", "state1", "qmap1", "absmax1")
]
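# 32-bit bnb optimizers must match the PyTorch reference state-for-state and
# parameter-for-parameter, including across a state_dict save/load round trip.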
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
for i in range(k):
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
bnb_optimizer.step()
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit
# but the parameters can diverge because they are 16-bit
# the differences grow larger and larger with each update
# --> copy the state to keep weights close
p1.data = p1.data.half().float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2)
if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
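# GlobalOptimManager: overriding optim_bits to 8 for a single registered
# parameter must switch only that parameter's optimizer state to uint8.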
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
mask = torch.rand_like(p2) < 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(
p3, "optim_bits", 8
)
bnb.optim.GlobalOptimManager.get_instance().register_parameters(
[p1, p2, p3]
)
p1 = p1.cuda()
p2 = p2.cuda()
p3 = p3.cuda()
adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
if gtype == torch.float32:
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
for i in range(50):
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
p1.grad = g1
p2.grad = g2
p3.grad = g3
adam2.step()
assert adam2.state[p3]["state1"].dtype == torch.uint8
assert adam2.state[p3]["state2"].dtype == torch.uint8
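# 8-bit optimizers: dequantized states must stay close to the 32-bit
# reference, and serialization must preserve the raw quantized buffers
# exactly.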
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
"adam8bit",
"momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-5, 1e-3
else:
atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-5, 1e-3
errors = []
relerrors = []
for i in range(50):
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
bnb_optimizer.step()
torch_optimizer.step()
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
# print(bnb_optimizer.state[p2][max_val], name1)
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
err = torch.abs(p1 - p2)
relerr = err / torch.abs(p1)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(
raws1cpy, bnb_optimizer.state[p2][name2]
)
torch.testing.assert_allclose(
qmap1, bnb_optimizer.state[p2][qmap]
)
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
torch.testing.assert_allclose(s1cpy, s1)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1],
s1,
atol=atol,
rtol=rtol,
)
== 0
)
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(
p1, p2.float(), atol=patol, rtol=prtol
)
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
torch_optimizer.state[p1][name1].copy_(s.data)
# print(sum(errors)/len(errors))
# print(sum(relerrors)/len(relerrors))
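# Manually scaling gradients with F.percentile_clipping must match an Adam
# instance constructed with percentile_clipping=5.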
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
eps = 1e-8
p1 = p1.cuda()
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam(
[p2],
lr,
(beta1, beta2),
eps,
optim_bits=optim_bits,
percentile_clipping=5,
)
gnorm_vec = torch.zeros(100).cuda()
step = 0
for i in range(50):
step += 1
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
0.01 * i
)
g2 = g1.clone()
p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
g1, gnorm_vec, step, 5
)
g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1
adam1.step()
adam2.step()
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
torch.testing.assert_allclose(
adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=5e-5,
rtol=1e-4,
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=5e-5,
rtol=1e-4,
)
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(
adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=2,
rtol=1e-3,
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=2,
rtol=1e-3,
)
adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
if i % 10 == 0 and i > 0:
path = get_temp_dir()
torch.save(adam2.state_dict(), join(path, "opt.pt"))
del adam2
adam2 = None
adam2 = bnb.optim.Adam(
[p2],
lr,
(beta1, beta2),
eps,
optim_bits=optim_bits,
percentile_clipping=5,
)
adam2.load_state_dict(torch.load(join(path, "opt.pt")))
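# Benchmark per-parameter update time for the selected 8-bit blockwise
# optimizer after a burn-in phase.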
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
if i == k // 5:
# skip the first k//5 iterations as burn-in before timing
torch.cuda.synchronize()
t0 = time.time()
bnb_optimizer.step()
torch.cuda.synchronize()
s = time.time() - t0
print("")
params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s / params)
# assert s < 3.9
./setup.py:20:10: F541 f-string is missing placeholders
./setup.py:21:13: F541 f-string is missing placeholders
./quicktest.py:5:1: F401 'bitsandbytes as bnb' imported but unused
./bitsandbytes/cuda_setup.py:42:56: F821 undefined name 'error_str'
./bitsandbytes/cuda_setup.py:43:15: F541 f-string is missing placeholders
./bitsandbytes/cuda_setup.py:67:5: F841 local variable 'context' is assigned to but never used
./bitsandbytes/cuda_setup.py:68:5: F841 local variable 'error_str' is assigned to but never used
./bitsandbytes/cuda_setup.py:76:9: F841 local variable 'result' is assigned to but never used
./bitsandbytes/cuda_setup.py:144:13: F841 local variable 'has_gpu' is assigned to but never used
./bitsandbytes/functional.py:294:13: F821 undefined name 'math'
./bitsandbytes/functional.py:295:16: F821 undefined name 'math'
./bitsandbytes/functional.py:303:5: F841 local variable 'ptrA' is assigned to but never used
./bitsandbytes/functional.py:304:5: F841 local variable 'ptrOut' is assigned to but never used
./bitsandbytes/functional.py:1057:17: W503 line break before binary operator
./bitsandbytes/functional.py:1058:17: W503 line break before binary operator
./bitsandbytes/functional.py:1059:17: W503 line break before binary operator
./bitsandbytes/functional.py:1649:1: F811 redefinition of unused 'get_special_format_str' from line 160
./bitsandbytes/functional.py:1687:5: F841 local variable 'ptrA' is assigned to but never used
./bitsandbytes/functional.py:1688:5: F841 local variable 'ptrOut' is assigned to but never used
./bitsandbytes/functional.py:1802:5: F841 local variable 'ccolsA' is assigned to but never used
./bitsandbytes/functional.py:1805:5: F841 local variable 'cldb' is assigned to but never used
./bitsandbytes/functional.py:1806:5: F841 local variable 'cldc' is assigned to but never used
./bitsandbytes/functional.py:1873:9: F841 local variable 'dtype' is assigned to but never used
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.MatmulLtState' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.bmm_cublas' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul_cublas' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.mm_cublas' imported but unused
./bitsandbytes/__init__.py:9:1: F401 '.nn.modules' imported but unused
./bitsandbytes/__init__.py:12:5: F401 '.optim.adam' imported but unused
./bitsandbytes/autograd/_functions.py:5:1: F401 'bitsandbytes as bnb' imported but unused
./bitsandbytes/autograd/_functions.py:12:75: W291 trailing whitespace
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Int8Params' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bit' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bitLt' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.StableEmbedding' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Any' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Callable' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Dict' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Iterator' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Mapping' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Set' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Tuple' imported but unused
./bitsandbytes/nn/modules.py:11:1: F401 'torch.nn.parameter.Parameter' imported but unused
./bitsandbytes/nn/modules.py:183:13: W503 line break before binary operator
./bitsandbytes/nn/modules.py:184:13: W503 line break before binary operator
./bitsandbytes/nn/modules.py:272:24: F821 undefined name 'dist'
./bitsandbytes/nn/modules.py:272:49: F821 undefined name 'dist'
./bitsandbytes/optim/optimizer.py:243:9: F841 local variable 'overflows' is assigned to but never used
./bitsandbytes/optim/optimizer.py:280:35: F541 f-string is missing placeholders
./bitsandbytes/optim/optimizer.py:283:35: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:27:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:59:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:91:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:157:13: F841 local variable 'params_with_grad' is assigned to but never used
./bitsandbytes/optim/lars.py:158:13: F841 local variable 'd_p_list' is assigned to but never used
./bitsandbytes/optim/lars.py:159:13: F841 local variable 'momentum_buffer_list' is assigned to but never used
./bitsandbytes/optim/lars.py:174:35: F821 undefined name 'param'
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam' imported but unused
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam8bit' imported but unused
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam32bit' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW8bit' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW32bit' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD8bit' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD32bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS8bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS32bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.PytorchLARS' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB8bit' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB32bit' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop8bit' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop32bit' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad8bit' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad32bit' imported but unused
./bitsandbytes/optim/__init__.py:17:1: F401 '.optimizer.GlobalOptimManager' imported but unused
./bitsandbytes/optim/adam.py:229:21: F841 local variable 'max_exp_avg_sq' is assigned to but never used
./bitsandbytes/optim/rmsprop.py:25:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:27:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:59:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:61:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:94:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:96:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:24:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:55:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:86:39: F541 f-string is missing placeholders
./tests/test_optim.py:1:1: F401 'ctypes' imported but unused
./tests/test_optim.py:199:5: F841 local variable 'mask' is assigned to but never used
./tests/test_optim.py:218:9: F841 local variable 'atol' is assigned to but never used
./tests/test_optim.py:218:15: F841 local variable 'rtol' is assigned to but never used
./tests/test_optim.py:304:17: W503 line break before binary operator
./tests/test_optim.py:354:21: W503 line break before binary operator
./tests/test_autograd.py:309:13: F841 local variable 'err' is assigned to but never used
./tests/test_cuda_setup_evaluator.py:31:9: F821 undefined name 'test_dir'
./tests/test_cuda_setup_evaluator.py:33:14: F821 undefined name 'test_input'
./tests/test_cuda_setup_evaluator.py:81:32: E203 whitespace before ':'
./tests/test_functional.py:55:13: F841 local variable 'ms' is assigned to but never used
./tests/test_functional.py:177:5: F841 local variable 'diffs' is assigned to but never used
./tests/test_functional.py:178:5: F841 local variable 'reldiffs' is assigned to but never used
./tests/test_functional.py:260:5: F841 local variable 'minA' is assigned to but never used
./tests/test_functional.py:261:5: F841 local variable 'maxA' is assigned to but never used
./tests/test_functional.py:584:5: F841 local variable 'func' is assigned to but never used
./tests/test_functional.py:617:17: F841 local variable 'offset' is assigned to but never used
./tests/test_functional.py:618:17: F841 local variable 'col2' is assigned to but never used
./tests/test_functional.py:619:17: F841 local variable 'row2' is assigned to but never used
./tests/test_functional.py:705:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:706:9: F841 local variable 'C2' is assigned to but never used
./tests/test_functional.py:715:9: F841 local variable 'output' is assigned to but never used
./tests/test_functional.py:750:5: F841 local variable 'formatB' is assigned to but never used
./tests/test_functional.py:754:5: F841 local variable 'w2' is assigned to but never used
./tests/test_functional.py:763:5: F841 local variable 'dtype' is assigned to but never used
./tests/test_functional.py:770:9: F841 local variable 'out1' is assigned to but never used
./tests/test_functional.py:1108:5: F841 local variable 'relerr1' is assigned to but never used
./tests/test_functional.py:1108:14: F841 local variable 'relerr2' is assigned to but never used
./tests/test_functional.py:1114:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1135:9: F841 local variable 'C4' is assigned to but never used
./tests/test_functional.py:1179:5: F841 local variable 'err1' is assigned to but never used
./tests/test_functional.py:1179:11: F841 local variable 'err2' is assigned to but never used
./tests/test_functional.py:1179:17: F841 local variable 'err3' is assigned to but never used
./tests/test_functional.py:1180:5: F841 local variable 'relerr1' is assigned to but never used
./tests/test_functional.py:1180:14: F841 local variable 'relerr2' is assigned to but never used
./tests/test_functional.py:1192:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1313:9: F841 local variable 'c' is assigned to but never used
./tests/test_functional.py:1314:9: F841 local variable 'c2' is assigned to but never used
./tests/test_functional.py:1406:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1425:9: F841 local variable 'out2' is assigned to but never used
./tests/test_functional.py:1542:5: F841 local variable 'idx_col' is assigned to but never used
./tests/test_functional.py:1566:30: E203 whitespace before ':'
./tests/test_functional.py:1568:38: E203 whitespace before ':'
./tests/test_functional.py:1655:5: F841 local variable 'offset' is assigned to but never used
./tests/test_functional.py:1706:9: F841 local variable 'out' is assigned to but never used
./tests/test_functional.py:1822:9: F841 local variable 'out' is assigned to but never used
./tests/test_functional.py:1882:5: F841 local variable 'out2' is assigned to but never used
./tests/test_functional.py:1928:9: F841 local variable 'dtype' is assigned to but never used
./tests/test_functional.py:1934:9: F841 local variable 'minx' is assigned to but never used
./tests/test_functional.py:1948:5: F841 local variable 'C0' is assigned to but never used
./tests/test_modules.py:1:1: F401 'itertools.product' imported but unused
./tests/test_modules.py:52:9: F841 local variable 'norm' is assigned to but never used
./tests/test_modules.py:52:16: F821 undefined name 'math'
./tests/test_modules.py:52:26: F821 undefined name 'math'
./tests/test_modules.py:52:37: F821 undefined name 'math'
./tests/test_modules.py:177:21: F821 undefined name 'einops'
./tests/test_modules.py:233:9: F841 local variable 'stochastic' is assigned to but never used
./tests/test_modules.py:382:9: F841 local variable 'o1' is assigned to but never used