"docs/vscode:/vscode.git/clone" did not exist on "52d4449810c8e13eb22b57e706e0e03806247da2"
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import itertools
import operator
import random
import torch
import math
from scipy.stats import norm
import numpy as np

from functools import reduce  # Required in Python 3
from typing import Tuple

from torch import Tensor

from .cextension import COMPILED_WITH_CUDA, lib


# math.prod not compatible with python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)


name2qmap = {}

if COMPILED_WITH_CUDA:
    """C FUNCTIONS FOR OPTIMIZERS"""
    str2optimizer32bit = {}
    str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
    str2optimizer32bit["momentum"] = (
        lib.cmomentum32bit_g32,
        lib.cmomentum32bit_g16,
    )
    str2optimizer32bit["rmsprop"] = (
        lib.crmsprop32bit_g32,
        lib.crmsprop32bit_g16,
    )
    str2optimizer32bit["adagrad"] = (
        lib.cadagrad32bit_g32,
        lib.cadagrad32bit_g16,
    )
    str2optimizer32bit["lars"] = (
        lib.cmomentum32bit_g32,
        lib.cmomentum32bit_g16,
    )
    str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

    str2optimizer8bit = {}
    str2optimizer8bit["adam"] = (
        lib.cadam_static_8bit_g32,
        lib.cadam_static_8bit_g16,
    )
    str2optimizer8bit["momentum"] = (
        lib.cmomentum_static_8bit_g32,
        lib.cmomentum_static_8bit_g16,
    )
    str2optimizer8bit["rmsprop"] = (
        lib.crmsprop_static_8bit_g32,
        lib.crmsprop_static_8bit_g16,
    )
    str2optimizer8bit["lamb"] = (
        lib.cadam_static_8bit_g32,
        lib.cadam_static_8bit_g16,
    )
    str2optimizer8bit["lars"] = (
        lib.cmomentum_static_8bit_g32,
        lib.cmomentum_static_8bit_g16,
    )

    str2optimizer8bit_blockwise = {}
    str2optimizer8bit_blockwise["adam"] = (
        lib.cadam_8bit_blockwise_fp32,
        lib.cadam_8bit_blockwise_fp16,
        lib.cadam_8bit_blockwise_bf16,
    )
    str2optimizer8bit_blockwise["momentum"] = (
        lib.cmomentum_8bit_blockwise_fp32,
        lib.cmomentum_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["rmsprop"] = (
        lib.crmsprop_8bit_blockwise_fp32,
        lib.crmsprop_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["adagrad"] = (
        lib.cadagrad_8bit_blockwise_fp32,
        lib.cadagrad_8bit_blockwise_fp16,
    )


class CUBLAS_Context:
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = {}
        # prev_device = torch.cuda.current_device()
        # for i in range(torch.cuda.device_count()):
        #    torch.cuda.set_device(torch.device('cuda', i))
        #    self.context.append(ct.c_void_p(lib.get_context()))
        # torch.cuda.set_device(prev_device)

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def get_context(self, device):
        if device.index not in self.context:
            prev_device = torch.cuda.current_device()
            torch.cuda.set_device(device)
            self.context[device.index] = ct.c_void_p(lib.get_context())
            torch.cuda.set_device(prev_device)
        return self.context[device.index]


class Cusparse_Context:
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = ct.c_void_p(lib.get_cusparse())

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance


def create_linear_map(signed=True, total_bits=8, add_zero=True):
    sign = (-1.0 if signed else 0.0)
    total_values = 2**total_bits
    if add_zero or total_bits < 8:
        # add a zero
        # since we simulate fewer bits by having zeros in the data type, we
        # need to center the quantization around zero and as such lose
        # a single value
        total_values = (2**total_bits if not signed else 2**total_bits-1)

    values = torch.linspace(sign, 1.0, total_values)
    gap = 256 - values.numel()
    if gap == 0:
        return values
    else:
        l = values.numel()//2
        #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
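
# Example (editor's sketch, not part of the original docs): the default signed
# 8-bit linear map spans [-1, 1] with 255 evenly spaced values plus one
# padding zero so that the returned code always has 256 entries:
#
#     code = create_linear_map(signed=True, total_bits=8)
#     assert code.numel() == 256
#     assert code.min() == -1.0 and code.max() == 1.0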

def create_custom_map(seed=0, scale=0.01):
    v = [12, 10, 8, 6, 3, 2, 1]
    # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45
    # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48

    # 13B 100 steps:
    # - 4-bit evo: 86.02
    # - 4-bit norm: 78.73
    # - 4-bit FP4:
    # - 16-bit:

    # interval search on normal distribution
    #v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5
    #v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99
    #v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97
    #v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81
    #v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68
    #v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03
    #v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01
    #v = [1.5547735945968535, 1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47
    #v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85
    #v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287
    ##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293
    #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
    #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
    #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
    v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88

    # 7B evo start 
    #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905,  0.14122701] # 22.06
    #v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951]      
    #v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564]

    # 13B evo start
    #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
    #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
    #v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]

    # mean evo 7B + 13B
    #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]

    # theoretically optimal (0.93333)
    #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333

    if seed > 0:
        v = np.array(v)
        np.random.seed(seed)
        v += np.random.randn(7)*scale
        print(v.tolist())
        #v[0] +=  (np.random.randn(1)*0.001)[0]
        #v[-1] +=  (np.random.randn(1)*0.001)[0]
    #print(v[0], v[-1])
        v = v.tolist()
    values = v + [0]*(256-14) +  \
             v[::-1]

    values = torch.Tensor(values)
    values[0:7] *= -1
    values = values.sort().values
    values /= values.max()
    assert values.numel() == 256
    return values

def create_normal_map(offset=0.9677083, use_extra_value=True):

    if use_extra_value:
        # one more positive value, this is an asymmetric type
        v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
        v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
        v = v1 + v2 + v3
    else:
        v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
        v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
        v = v1 + v2 + v3

    values = torch.Tensor(v)
    values = values.sort().values
    values /= values.max()
    assert values.numel() == 256
    return values
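
# Example (editor's sketch): the default call builds the 256-entry
# normal-quantile ("NF4-style") code, normalized so the largest magnitude
# is 1.0 (requires scipy for norm.ppf):
#
#     code = create_normal_map()
#     assert code.numel() == 256
#     assert code.max() == 1.0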

def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
    e = exponent_bits
    p = precision_bits
    has_sign = 1 if signed else 0
    assert e+p == total_bits-has_sign
    # the exponent is biased to 2^(e-1) -1 == 0
    evalues = []
    pvalues = []
    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
        evalues.append(2**val)


    values = []
    lst = list(itertools.product([0, 1], repeat=precision_bits))
    #for ev in evalues:
    bias = 2**(exponent_bits-1)-1
    for evalue in range(2**(exponent_bits)):
        for bit_pattern in lst:
            value = (1 if evalue != 0 else 0)
            for i, pval in enumerate(list(bit_pattern)):
                value += pval*(2**-(i+1))
            if evalue == 0:
                # subnormals
                value = value*2**-(bias-1)
            else:
                # normals
                value = value*2**-(evalue-bias-2)
            values.append(value)
            if signed:
                values.append(-value)


    assert len(values) == 2**total_bits
    values.sort()
    if total_bits < 8:
        gap = 256 - len(values)
        for i in range(gap):
            values.append(0)
    values.sort()
    code = torch.Tensor(values)
    code /= code.max()

    return code
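
# Example (editor's sketch): an E4M3-style signed FP8 code, i.e. 4 exponent
# bits and 3 precision bits in an 8-bit value, rescaled to [-1, 1]:
#
#     code = create_fp8_map(signed=True, exponent_bits=4, precision_bits=3, total_bits=8)
#     assert code.numel() == 256
#     assert code.max() == 1.0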



def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    """
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and
    fraction. As the exponent moves from 0 to -7, the number
    of bits available for the fraction shrinks.

    This is a generalization of the dynamic type where a certain
    number of the bits can be reserved for the linear quantization
    region (the fraction). max_exponent_bits determines the maximum
    number of exponent bits.

    For more details see
    [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
    """

    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    non_sign_bits = total_bits - (1 if signed else 0)
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    if not signed:
        additional_items = 2 * additional_items
    for i in range(max_exponent_bits):
        fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
        if signed:
            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

        if additional_items > 0:
            boundaries = torch.linspace(0.1, 1, additional_items + 1)
            means = (boundaries[:-1] + boundaries[1:]) / 2.0
            data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
            if signed:
                data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

    data.append(0)
    data.append(1.0)

    gap = 256 - len(data)
    for i in range(gap):
        data.append(0)

    data.sort()
    return Tensor(data)
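
# Example (editor's sketch): the default dynamic map is a sorted 256-entry
# code in [-1, 1] that contains 0 and 1.0 and is densest around zero:
#
#     code = create_dynamic_map(signed=True)
#     assert code.numel() == 256
#     assert code.max() == 1.0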

def create_quantile_map(A, total_bits=8):
    q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
    q = q.tolist()
    q.append(0)

    gap = 256 - len(q)
    for i in range(gap):
        q.append(0)

    q.sort()

    q = Tensor(q)
    q = q/q.abs().max()
    return q

def get_special_format_str():
    if not torch.cuda.is_available(): return 'col_turing'
    major, _minor = torch.cuda.get_device_capability()
    if major <= 7:
        return "col_turing"
    if major == 8:
        return "col_ampere"
    return "col_turing"


def is_on_gpu(tensors):
    on_gpu = True
    for t in tensors:
        if t is None: continue # NULL pointers are fine
        on_gpu &= t.device.type == 'cuda'
    return on_gpu

def get_ptr(A: Tensor) -> ct.c_void_p:
    """
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters
    ----------
    A : torch.tensor
        The PyTorch tensor.

    Returns
    -------
    ctypes.c_void_p
    """
    if A is None:
        return None
    else:
        return ct.c_void_p(A.data.data_ptr())


def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device


def post_call(prev_device):
    torch.cuda.set_device(prev_device)


def get_transform_func(dtype, orderA, orderOut, transpose=False):
    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
    if not hasattr(lib, name):
        print(name)
        raise ValueError(
            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
        )
    else:
        return getattr(lib, name)


def get_transform_buffer(
    shape, dtype, device, to_order, from_order="row", transpose=False
):
    # init_func = torch.empty
    init_func = torch.zeros
    dims = len(shape)

    if dims == 2:
        rows = shape[0]
    elif dims == 3:
        rows = shape[0] * shape[1]
    cols = shape[-1]

    state = (shape, to_order)
    if transpose:
        # swap dims
        tmp = rows
        rows = cols
        cols = tmp
        state = (shape[::-1], to_order)

    if to_order == "row" or to_order == "col":
        return init_func(shape, dtype=dtype, device=device), state
    elif to_order == "col32":
        # blocks of 32 columns (padded)
        cols = 32 * ((cols + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_turing":
        # blocks of 32 columns and 8 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 8 * ((rows + 7) // 8)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_ampere":
        # blocks of 32 columns and 32 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 32 * ((rows + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    else:
        raise NotImplementedError(f"To_order not supported: {to_order}")


def nvidia_transform(
    A,
    to_order,
    from_order="row",
    out=None,
    transpose=False,
    state=None,
    ld=None,
):
    if state is None:
        state = (A.shape, from_order)
    else:
        from_order = state[1]
    if out is None:
        out, new_state = get_transform_buffer(
            state[0], A.dtype, A.device, to_order, state[1]
        )
    else:
        new_state = (state[1], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    elif ld is not None:
        n = prod(shape)
        dim1 = prod([shape[i] for i in ld])
        dim2 = ct.c_int32(n // dim1)
        dim1 = ct.c_int32(dim1)
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

    return out, new_state


def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses the SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
    and the extreme quantiles close to 0 and 1 have high variance / large estimation
    errors. These large errors can be avoided by using the offset variable which trims
    the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
    trims 1/512 = 0.2% from each side of the distribution. An offset value of 0.01 to 0.02
    usually has a much lower error but is not a minimum entropy encoding. Given an offset
    of 0.02, equidistant points in the range [0.02, 0.98] are used for the quantiles.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. Any shape.
    out : torch.Tensor
        Tensor with the 256 estimated quantiles.
    offset : float
        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
    num_quantiles : int
        The number of equally spaced quantiles.

    Returns
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.
    '''
    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
    if num_quantiles < 256 and offset == 1/(512):
        # override default arguments
        offset = 1/(2*num_quantiles)

    if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    is_on_gpu([A, out])
    device = pre_call(A.device)
    if A.dtype == torch.float32:
        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    elif A.dtype == torch.float16:
        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    else:
        raise NotImplementedError(f"Not supported data type {A.dtype}")
    post_call(device)

    if num_quantiles < 256:
        step = round(256/num_quantiles)
        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
        out = out[idx]

    return out
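
# Example (editor's sketch, assumes a CUDA device and the compiled library):
#
#     A = torch.randn(4096, device='cuda')
#     q256 = estimate_quantiles(A)                   # 256 approximate quantiles of A
#     q16 = estimate_quantiles(A, num_quantiles=16)  # coarser 16-quantile estimate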


def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
    """
    Quantize tensor A in blocks of size `blocksize` values (default 4096).

    Quantizes tensor A by dividing it into blocks of `blocksize` values.
    Then the absolute maximum value within each block is calculated
    for the non-linear quantization.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map.
    absmax : torch.Tensor
        The absmax values.
    rand : torch.Tensor
        The tensor for stochastic rounding.
    out : torch.Tensor
        The output tensor (8-bit).

    Returns
    -------
    torch.Tensor:
        The 8-bit tensor.
    tuple(torch.Tensor, torch.Tensor):
        The quantization state to undo the quantization.
    """


    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    if absmax is None:
        n = A.numel()
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != 'cpu':
        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
        cblocksize = ct.c_int32(blocksize)
        prev_device = pre_call(A.device)
        code = code.to(A.device)
        if rand is not None:
            is_on_gpu([code, A, out, absmax, rand])
            assert blocksize==4096
            assert rand.numel() >= 1024
            rand_offset = random.randint(0, 1023)
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            else:
                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        else:
            is_on_gpu([code, A, out, absmax])
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
            else:
                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        post_call(A.device)
    else:
        # cpu
        code = code.cpu()
        assert rand is None
        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    state = (absmax, code, blocksize)

    return out, state


def dequantize_blockwise(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
    blocksize: int = 4096,
) -> Tensor:
    """
    Dequantizes blockwise quantized values.

    Dequantizes the tensor A with maximum absolute values absmax in
    blocks of size 4096.

    Parameters
    ----------
    A : torch.Tensor
        The input 8-bit tensor.
    quant_state : tuple(torch.Tensor, torch.Tensor, int)
        Tuple of absmax values, quantization map (code) and blocksize.
    absmax : torch.Tensor
        The absmax values.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor
        Dequantized output tensor (default: float32)


    Returns
    -------
    torch.Tensor:
        Dequantized tensor (default: float32)
    """
659
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    if quant_state is None:
        quant_state = (absmax, code, blocksize)
    else:
        absmax, code, blocksize = quant_state


    if A.device.type != 'cpu':
        device = pre_call(A.device)
        code = code.to(A.device)
        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]:
            raise ValueError(f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64, 32]")
        is_on_gpu([A, absmax, out])
        if out.dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
        elif out.dtype == torch.float16:
            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
        post_call(A.device)
    else:
        code = code.cpu()
        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    return out
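
# Example (editor's sketch of the blockwise round trip; tensor names are
# illustrative and a CUDA device plus the compiled library are assumed):
#
#     A = torch.randn(1024, 1024, device='cuda')
#     qA, state = quantize_blockwise(A)        # uint8 codes + (absmax, code, blocksize)
#     A_hat = dequantize_blockwise(qA, state)  # float32 reconstruction of A
#     max_err = (A - A_hat).abs().max()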

def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
    return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'fp4')

def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
    return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'nf4')

def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
    """
    Quantize tensor A in blocks of 4-bit values (FP4 or NF4).

    Quantizes tensor A by dividing it into blocks which are independently quantized to the 4-bit data type given by `quant_type`.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    absmax : torch.Tensor
        The absmax values.
    out : torch.Tensor
        The output tensor (8-bit).
    blocksize : int
        The blocksize used in quantization.
    quant_type : str
        The 4-bit quantization data type {fp4, nf4}

    Returns
    -------
    torch.Tensor:
        The 8-bit tensor with packed 4-bit values.
    tuple(torch.Tensor, torch.Size, torch.dtype, int):
        The quantization state to undo the quantization.
    """
    if A.device.type != 'cuda':
        raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
    if quant_type not in ['fp4', 'nf4']:
        raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')

    n = A.numel()
    input_shape = A.shape

    if absmax is None:
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)


    if out is None:
        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)

    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]

    prev_device = pre_call(A.device)
    is_on_gpu([A, out, absmax])

    if A.dtype == torch.float32:
        if quant_type == 'fp4':
            lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
        else:
            lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
    elif A.dtype == torch.float16:
        if quant_type == 'fp4':
            lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
        else:
            lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
    post_call(A.device)

    if compress_statistics:
        offset = absmax.mean()
        absmax -= offset
        #code = create_custom_map().to(absmax.device)
        #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
        qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
        del absmax
        state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2))
    else:
        state = (absmax, input_shape, A.dtype, blocksize, None)

    return out, state

def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
    return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'fp4')

def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
    return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'nf4')

def dequantize_4bit_packed(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
    """
    Dequantizes packed 4-bit (FP4 or NF4) blockwise quantized values.

    Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.

    Parameters
    ----------
    A : torch.Tensor
        The input 8-bit tensor (packed 4-bit values).
    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype, int, tuple)
        Tuple of absmax values, original tensor shape, original dtype, blocksize and, optionally, the compressed absmax statistics.
    absmax : torch.Tensor
        The absmax values.
    out : torch.Tensor
        Dequantized output tensor.
    blocksize : int
        The blocksize used in quantization.
    quant_type : str
        The 4-bit quantization data type {fp4, nf4}


    Returns
    -------
    torch.Tensor:
        Dequantized tensor.
    """
    if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
        raise ValueError(f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
    if quant_type not in ['fp4', 'nf4']:
        raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')

    if quant_state is None:
        assert absmax is not None and out is not None
        shape = out.shape
        dtype = out.dtype
        compressed_stats = None
    else:
        absmax, shape, dtype, blocksize, compressed_stats = quant_state

    if compressed_stats is not None:
        offset, state2 = compressed_stats
        absmax = dequantize_blockwise(absmax, state2)
        absmax += offset

    if out is None:
        out = torch.empty(shape, dtype=dtype, device=A.device)

    n = out.numel()


    device = pre_call(A.device)
    is_on_gpu([A, absmax, out])
    if out.dtype == torch.float32:
        if quant_type == 'fp4':
            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
        else:
            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
    elif out.dtype == torch.float16:
        if quant_type == 'fp4':
            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
        else:
            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
    post_call(A.device)

    is_transposed = (True if A.shape[0] == 1 else False)
    if is_transposed: return out.t()
    else: return out
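
# Example (editor's sketch, assumes a CUDA device): NF4 round trip through the
# packed 4-bit representation (two values per byte):
#
#     W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
#     qW, state = quantize_nf4(W)        # uint8 tensor of shape ((n+1)//2, 1)
#     W_hat = dequantize_nf4(qW, state)  # fp16 tensor with W.shape restored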


def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    absmax = torch.abs(A).max()
    inp = A / absmax
    out = quantize_no_absmax(inp, code, out)
    return out, (absmax, code)


def dequantize(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
) -> Tensor:
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if quant_state is None:
        quant_state = (absmax, code)
    out = dequantize_no_absmax(A, quant_state[1], out)
    return out * quant_state[0]
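
# Example (editor's sketch): whole-tensor 8-bit quantization with a single
# absmax scale, as opposed to the blockwise variants above (assumes the
# compiled library and a CUDA device):
#
#     A = torch.randn(256, 256, device='cuda')
#     qA, state = quantize(A)        # uint8 codes + (absmax, code)
#     A_hat = dequantize(qA, state)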


def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Quantizes input tensor to 8-bit.

    Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
    `out` using the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor, optional
        The output tensor. Needs to be of type byte.

    Returns
    -------
    torch.Tensor:
        Quantized 8-bit tensor.
    '''
    if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
    is_on_gpu([A, out])
    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out


def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Dequantizes the 8-bit tensor to 32-bit.

    Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
    the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The 8-bit input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor
        The 32-bit output tensor.

    Returns
    -------
    torch.Tensor:
        32-bit output tensor.
    '''
    if out is None: out = torch.zeros_like(A, dtype=torch.float32)
    is_on_gpu([code, A, out])
    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out


def optimizer_update_32bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Tensor = None,
    beta2: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with one or two optimizer states.

    Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer: {adam}.
    g : torch.Tensor
        Gradient tensor.
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Optimizer state 1.
    beta1 : float
        Optimizer beta1.
    eps : float
        Optimizer epsilon.
    weight_decay : float
        Weight decay.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    state2 : torch.Tensor
        Optimizer state 2.
    beta2 : float
        Optimizer beta2.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).
    """

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    if optimizer_name not in str2optimizer32bit:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )

    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][0](
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
    elif g.dtype == torch.float16 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][1](
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
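
# Example (editor's sketch, assumes CUDA tensors; `m` and `v` play the role of
# the Adam state buffers normally managed by the optimizer classes):
#
#     p = torch.randn(1024, device='cuda'); g = torch.randn_like(p)
#     m = torch.zeros_like(p); v = torch.zeros_like(p)
#     optimizer_update_32bit('adam', g, p, m, beta1=0.9, eps=1e-8, step=1, lr=1e-3,
#                            state2=v, beta2=0.999)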


def optimizer_update_8bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    max1: Tensor,
    max2: Tensor,
    new_max1: Tensor,
    new_max2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
) -> None:
    """
    Performs an inplace Adam update.

    Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
    Uses AdamW formulation if weight decay > 0.0.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer. Choices {adam, momentum}
    g : torch.Tensor
        Gradient tensor.
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Adam state 1.
    state2 : torch.Tensor
        Adam state 2.
    beta1 : float
        Adam beta1.
    beta2 : float
        Adam beta2.
    eps : float
        Adam epsilon.
    weight_decay : float
        Weight decay.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1 : torch.Tensor
        Quantization map for first Adam state.
    qmap2 : torch.Tensor
        Quantization map for second Adam state.
    max1 : torch.Tensor
        Max value for first Adam state update.
    max2 : torch.Tensor
        Max value for second Adam state update.
    new_max1 : torch.Tensor
        Max value for the next Adam update of the first state.
    new_max2 : torch.Tensor
        Max value for the next Adam update of the second state.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.
    """

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][0](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][1](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )


def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    absmax1: Tensor,
    absmax2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:

    optimizer_func = None
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1]
    elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
          len(str2optimizer8bit_blockwise[optimizer_name])==3):
        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])

    prev_device = pre_call(g.device)
    optimizer_func(
        get_ptr(p),
        get_ptr(g),
        get_ptr(state1),
        get_ptr(state2),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
        get_ptr(qmap1),
        get_ptr(qmap2),
        get_ptr(absmax1),
        get_ptr(absmax2),
        ct.c_float(weight_decay),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()),
    )
    post_call(prev_device)

def percentile_clipping(
    grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
    """Applies percentile clipping

    grad: torch.Tensor
        The gradient tensor.
    gnorm_vec: torch.Tensor
        Vector of gradient norms. 100 elements expected.
    step: int
        The current optimization step (number of past gradient norms).
    percentile: int
        The percentile of the last 100 gradient norms used as the clipping value.

    """
    is_on_gpu([grad, gnorm_vec])
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )
    elif grad.dtype == torch.float16:
        lib.cpercentile_clipping_g16(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
    clip_value = torch.sqrt(vals[percentile])
    gnorm_scale = 1.0

    if current_gnorm > clip_value:
        gnorm_scale = clip_value / current_gnorm

    return current_gnorm, clip_value, gnorm_scale


def histogram_scatter_add_2d(
    histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
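    # Scatter-add on the GPU: histogram[index1[i], index2[i]] += source[i] for each i,
    # with index1/index2/source treated as parallel arrays.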
    assert len(histogram.shape) == 2
    assert histogram.dtype == torch.float32
    assert source.dtype == torch.float32
    assert index1.dtype == torch.int32
    assert index2.dtype == torch.int32

    assert histogram.device.type == "cuda"
    assert index1.device.type == "cuda"
    assert index2.device.type == "cuda"
    assert source.device.type == "cuda"

    maxdim1 = ct.c_int32(histogram.shape[0])
    n = ct.c_int32(index1.numel())
    is_on_gpu([histogram, index1, index2, source])
    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)

def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    if not torch.cuda.is_initialized(): torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(
            f"Expected {expected_type} input tensors A and B, but got {A.dtype} and {B.dtype}"
        )

    sA = A.shape
    sB = B.shape
    tA = transposed_A
    tB = transposed_B

    correct = True

    if len(sA) == 2 and len(sB) == 2:
        if not tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[0] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[0] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[1] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 2:
        if not tA and not tB and A.shape[2] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 3:
        if not tA and not tB and A.shape[2] != B.shape[1]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[1]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[2]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[2]:
            correct = False

    if out is not None:
        sout = out.shape
        # special case common in backprop
        if not correct and len(sA) == 3 and len(sB) == 3:
            if (
                sout[0] == sA[2]
                and sout[1] == sB[2]
                and sA[0] == sB[0]
                and sA[1] == sB[1]
            ):
                correct = True
    else:
        if len(sA) == 2 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sB[1])
            elif tA and tB:
                sout = (sA[1], sB[0])
            elif tA and not tB:
                sout = (sA[1], sB[1])
            elif not tA and tB:
                sout = (sA[0], sB[0])
        elif len(sA) == 3 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[1])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[0])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[1])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[0])
        elif len(sA) == 3 and len(sB) == 3:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[2])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[1])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[2])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[1])

    if not correct:
        raise ValueError(
            f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
        )

    return sout


def igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            return batched_igemm(A, B, out)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[0])
    # this is a mess: cuBLAS expects column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
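    # Worked example (illustrative): for row-major A[8, 4] @ B[4, 16] = C[8, 16],
    # cuBLAS is handed the column-major views B^T[16, 4] and A^T[4, 8]; their product
    # C^T[16, 8] in column-major memory is exactly the row-major C[8, 16] we want.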
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    is_on_gpu([B, A, out])
    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
    return out


def batched_igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    if not len(A.shape) == 3 or not len(B.shape) == 3:
        raise ValueError(
            f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
        )
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    if B.is_contiguous():
        lda = B.stride()[1]
        transposed_A = False
    else:
        s = B.stride()
        if s[0] != B.shape[0]:
            B = B.contiguous()
            lda = B.stride()[1]
        elif s[2] == B.shape[1]:
            transposed_A = True
            lda = B.stride()[2]
        else:
            if s[2] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            elif s[1] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            else:
                B = B.contiguous()
                lda = B.stride()[1]

    if A.is_contiguous():
        ldb = A.stride()[1]
        transposed_B = False
    else:
        s = A.stride()
        if s[0] != A.shape[0]:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False
        elif s[2] == A.shape[1]:
            ldb = A.stride()[2]
            transposed_B = True
        else:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False

    # this is a mess: cuBLAS expects column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS

    # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
    num_batch = A.shape[0]
    n = A.shape[1]
    m = B.shape[2]
    k = B.shape[1]

    ldc = m

    strideA = B.shape[1] * B.shape[2]
    strideB = A.shape[1] * A.shape[2]
    strideC = A.shape[1] * B.shape[2]

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    is_on_gpu([B, A, out])
    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
               ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
    return out


def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    shapeA = SA[0]
    shapeB = SB[0]
    dimsA = len(shapeA)
    dimsB = len(shapeB)
    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]

    rows = n = shapeB[0]
    assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'

    # if the tensor is empty, return a transformed empty tensor with the right dimensions
    if shapeA[0] == 0 and dimsA == 2:
        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
    elif shapeA[1] == 0 and dimsA == 3:
        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
        )
    elif dimsA == 3 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
        )

    assert dimsB != 3, "len(B.shape)==3 not supported"
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    assert out.dtype == dtype
    assert SA[1] == "col32"
    assert SB[1] in ["col_turing", "col_ampere"]
    assert Sout[1] == "col32"
    assert (
        shapeA[-1] == shapeB[-1]
    ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
    formatB = SB[1]
    prev_device = A.device
    torch.cuda.set_device(A.device)

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    ptrA = get_ptr(A)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)

    k = shapeA[-1]
    lda = ct.c_int32(m * 32)
    if formatB == "col_turing":
        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
    else:
        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)

    ldc = ct.c_int32(m * 32)
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)

    has_error = 0
    ptrRowScale = get_ptr(None)
    is_on_gpu([A, B, out])
    if formatB == 'col_turing':
        if dtype == torch.int32:
            has_error = lib.cigemmlt_turing_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_turing_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
    elif formatB == "col_ampere":
        if dtype == torch.int32:
            has_error = lib.cigemmlt_ampere_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_ampere_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )

    if has_error == 1:
        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
        raise Exception('cublasLt ran into an error!')

    torch.cuda.set_device(prev_device)

    return out, Sout


def mm_dequant(
    A,
    quant_state,
    row_stats,
    col_stats,
    out=None,
    new_row_stats=None,
    new_col_stats=None,
    bias=None
):
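    # Dequantizes the int32 output of igemmlt back to fp16 using the per-row and
    # per-column absmax statistics gathered during quantization; optionally adds a fp16 bias.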
    assert A.dtype == torch.int32
    if bias is not None: assert bias.dtype == torch.float16
    out_shape = quant_state[0]
    if len(out_shape) == 3:
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])

    if out is None:
        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
    if new_row_stats is None:
        new_row_stats = torch.empty(
            out_shape[0], dtype=torch.float32, device=A.device
        )
    if new_col_stats is None:
        new_col_stats = torch.empty(
            out_shape[1], dtype=torch.float32, device=A.device
        )
    assert (
        new_row_stats.shape[0] == row_stats.shape[0]
    ), f"{new_row_stats.shape} vs {row_stats.shape}"
    assert (
        new_col_stats.shape[0] == col_stats.shape[0]
    ), f"{new_col_stats.shape} vs {col_stats.shape}"

    prev_device = pre_call(A.device)
    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    ptrNewRowStats = get_ptr(new_row_stats)
    ptrNewColStats = get_ptr(new_col_stats)
    ptrBias = get_ptr(bias)
    numRows = ct.c_int32(out_shape[0])
    numCols = ct.c_int32(out_shape[1])

    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols)
    post_call(prev_device)

    return out


def get_colrow_absmax(
    A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
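    # Computes per-row and per-column absolute maxima of A (used as quantization
    # statistics); when threshold > 0 it also returns counts of above-threshold
    # outliers that double_quant uses to size its sparse COO tensor.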
    assert A.dtype == torch.float16
    device = A.device

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    col_tiles = (cols + 255) // 256
    tiled_rows = ((rows + 15) // 16) * 16
    if row_stats is None:
        row_stats = torch.empty(
            (rows,), dtype=torch.float32, device=device
        ).fill_(-50000.0)
    if col_stats is None:
        col_stats = torch.empty(
            (cols,), dtype=torch.float32, device=device
        ).fill_(-50000.0)

    if nnz_block_ptr is None and threshold > 0.0:
        nnz_block_ptr = torch.zeros(
            ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
        )

    ptrA = get_ptr(A)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    ptrNnzrows = get_ptr(nnz_block_ptr)
    rows = ct.c_int32(rows)
    cols = ct.c_int32(cols)

    prev_device = pre_call(A.device)
    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
    post_call(prev_device)

    if threshold > 0.0:
        nnz_block_ptr.cumsum_(0)

    return row_stats, col_stats, nnz_block_ptr


class COOSparseTensor:
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        assert rowidx.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colidx.numel() == nnz

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values


class CSRSparseTensor:
    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        assert rowptr.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert colidx.numel() == nnz
        assert rowptr.numel() == rows + 1

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowptr = rowptr
        self.colidx = colidx
        self.values = values


class CSCSparseTensor:
    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        assert colptr.dtype == torch.int32
        assert rowidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colptr.numel() == cols + 1

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.colptr = colptr
        self.rowidx = rowidx
        self.values = values


def coo2csr(cooA):
    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    values.add_(1)
    rowptr = torch.zeros(
        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
    )
    rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
    rowptr.cumsum_(0)
    return CSRSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
    )

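# Worked example for coo2csr (illustrative): a 3x4 matrix with nonzeros at
# (0, 1), (0, 3) and (2, 2) has rowidx=[0, 0, 2]; the scatter+cumsum above yields
# rowptr=[0, 2, 2, 3], i.e. two values in row 0, none in row 1, one in row 2.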

def coo2csc(cooA):
    val, col2rowidx = torch.sort(cooA.colidx)
    rowidx = cooA.rowidx[col2rowidx]
    values = cooA.values[col2rowidx]
    colvalues, counts = torch.unique(val, return_counts=True)
    colvalues.add_(1)
    colptr = torch.zeros(
        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
    )
    colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
    colptr.cumsum_(0)
    return CSCSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
    )


def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    values = torch.zeros((nnz,), dtype=dtype, device=device)
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


def double_quant(
    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
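    # Quantizes A to int8 twice, row-wise (out_row) and column-wise (out_col), and
    # returns the per-row/per-column absmax statistics; values whose magnitude exceeds
    # `threshold` are excluded from the int8 output and returned as a fp16 COO tensor.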
    device = A.device
    assert A.dtype == torch.half
    assert device.type == "cuda"
    prev_device = pre_call(A.device)

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    if row_stats is None or col_stats is None:
        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
            A, threshold=threshold
        )

    if out_col is None:
        out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
    if out_row is None:
        out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

    coo_tensor = None
    ptrA = get_ptr(A)
    ptrColStats = get_ptr(col_stats)
    ptrRowStats = get_ptr(row_stats)
    ptrOutCol = get_ptr(out_col)
    ptrOutRow = get_ptr(out_row)

    is_on_gpu([A, col_stats, row_stats, out_col, out_row])
    if threshold > 0.0:
        nnz = nnz_row_ptr[-1].item()
        if nnz > 0:
            coo_tensor = coo_zeros(
                A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
            )
            ptrRowIdx = get_ptr(coo_tensor.rowidx)
            ptrColIdx = get_ptr(coo_tensor.colidx)
            ptrVal = get_ptr(coo_tensor.values)
            ptrRowPtr = get_ptr(nnz_row_ptr)

            lib.cdouble_rowcol_quant(
                ptrA,
                ptrRowStats,
                ptrColStats,
                ptrOutCol,
                ptrOutRow,
                ptrRowIdx,
                ptrColIdx,
                ptrVal,
                ptrRowPtr,
                ct.c_float(threshold),
                ct.c_int32(rows),
                ct.c_int32(cols),
            )
            val, idx = torch.sort(coo_tensor.rowidx)
            coo_tensor.rowidx = val
            coo_tensor.colidx = coo_tensor.colidx[idx]
            coo_tensor.values = coo_tensor.values[idx]
        else:
            lib.cdouble_rowcol_quant(
                ptrA,
                ptrRowStats,
                ptrColStats,
                ptrOutCol,
                ptrOutRow,
                None,
                None,
                None,
                None,
                ct.c_float(0.0),
                ct.c_int32(rows),
                ct.c_int32(cols),
            )
    else:
        lib.cdouble_rowcol_quant(
            ptrA,
            ptrRowStats,
            ptrColStats,
            ptrOutCol,
            ptrOutRow,
            None,
            None,
            None,
            None,
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
        )
    post_call(prev_device)

    return out_row, out_col, row_stats, col_stats, coo_tensor


def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    prev_device = pre_call(A.device)
    if state is None: state = (A.shape, from_order)
    else: from_order = state[1]
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
    else: new_state = (state[0], to_order) # (shape, order)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    is_on_gpu([A, out])
    if to_order == 'col32':
        if transpose:
            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "col_turing":
        if transpose:
            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "col_ampere":
        if transpose:
            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "row":
        if from_order == "col_turing":
            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
        elif from_order == "col_ampere":
            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
    else:
        raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')

    post_call(prev_device)

    return out, new_state

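# Illustrative sketch of how the int8 primitives above compose into a fp16 matmul
# (variable names here are placeholders, not the exact autograd implementation):
#
#   CA, CAt, statsA, statsAt, _ = double_quant(A_fp16)        # int8 rows/cols + absmax stats
#   CB, CBt, statsB, statsBt, _ = double_quant(B_fp16)        # B has shape (out, in), so A @ B^T
#   C32A, SA = transform(CA, "col32")                         # A into the col32 tile layout
#   CxB, SB = transform(CB, "col_turing")                     # or "col_ampere", depending on the GPU
#   out32, Sout32 = igemmlt(C32A, CxB, SA, SB)                 # int8 @ int8^T -> int32
#   out_fp16 = mm_dequant(out32, Sout32, statsA, statsB)      # rescale back to fp16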

def spmm_coo(cooA, B, out=None):
    if out is None:
        out = torch.empty(
            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = False if B.is_contiguous() else True

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    ptr = Cusparse_Context.get_instance().context

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))

    return out


def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    if out is None:
        out = torch.zeros(
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"

    transposed_B = False if B.is_contiguous() else True

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert (
        max_count[0] <= 32
    ), f"Maximum supported count per row is 32 but found {max_count[0]}."
    assert B.dtype in [torch.float16, torch.int8]
    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ptrMaxIdx = get_ptr(max_idx)

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrDequantStats = get_ptr(dequant_stats)
    cnnz_rows = ct.c_int32(counts.numel())
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    crowsB = ct.c_int32(B.shape[1])
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    elif B.dtype == torch.int8:
        lib.cspmm_coo_very_sparse_naive_int8(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    # other dtypes are rejected by the dtype assertion above

    return out


C = 127.0
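# C is the symmetric int8 quantization constant: the vectorwise routines below map
# values into [-127, 127] via round(x / absmax * C) and invert with xq / C * absmax.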


def vectorwise_quant(x, dim=1, quant_type="vector"):
    if quant_type == "linear":
        max1 = torch.abs(x).max().float()
        xq = torch.round(x / max1 * 127).to(torch.int8)
        return xq, max1
    elif quant_type in ["vector", "row"]:
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        xq = torch.round(x * (C / max1)).to(torch.int8)
        return xq, max1
    elif quant_type == "zeropoint":
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 255.0 / dyna
        minx = x.min()
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
        dtype = x.dtype
        x = x.float()
        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
            x, dim=dim, keepdim=True
        )
        dyna[dyna == 0] = 1
        qx = 255.0 / dyna
        minx = torch.amin(x, dim=dim, keepdim=True)
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type == "truncated-vector":
        with torch.no_grad():
            absx = torch.abs(x)
            max1 = torch.amax(absx, dim=dim, keepdim=True)
            max1 = max1 * 0.7
            idx = absx > max1.expand_as(absx)
            sign = torch.sign(x[idx])
            x[idx] = max1.expand_as(absx)[idx] * sign
            xq = torch.round(x / max1 * C).to(torch.int8)
        return xq, max1
    else:
        return None

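# Round-trip sketch for the "vector" quantization above (illustrative): in a row with
# absmax 2.0, the value 1.0 becomes round(1.0 * 127 / 2.0) = 64 in int8 and is
# dequantized by vectorwise_dequant below to 64 / 127 * 2.0, about 1.008.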

def vectorwise_dequant(xq, max1, quant_type="vector"):
    if quant_type == "vector":
        x = (xq / C * max1).to(torch.float32)
        return x
    else:
        return None


def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    if quant_type == "linear":
        norm = S1 * S2 / (C * C)
        # double cast needed to prevent overflows
        return (xq.float() * norm).to(dtype)
    elif quant_type == "zeropoint":
        norm = 1.0 / (S1 * S2)
        return (xq.float() * norm).to(dtype)
    elif quant_type == "row-zeropoint":
        norm = 1.0 / (S1 * S2)
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        x *= norm
        return x.to(dtype)
    elif quant_type == "vector-zeropoint":
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= 1.0 / S1
        else:
            x *= 1.0 / S1
        x *= 1.0 / S2.t()
        return x.to(dtype)
    elif quant_type == "row":
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1 * S2 / (C * C)
        else:
            x *= S1 * S2 / (C * C)
        return x.to(dtype)
    elif quant_type in ["truncated-vector", "vector"]:
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1 / C
        else:
            x *= S1 / C
        x *= S2 / C
        return x.to(dtype)
    else:
        return None


def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    offset = B.float().t().sum(0) * (SA[0] + SA[1])
    x = xq.float()
    if len(xq.shape) == 2 and len(SB.shape) == 3:
        SB = SB.squeeze(0)
    if len(SB.shape) == 2:
        x *= SB.t() / 127
    else:
        x *= SB / 127
    x *= SA[1] / 127
    x += offset
    return x.to(dtype)


def extract_outliers(A, SA, idx):
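    # Gathers the columns listed in `idx` out of the tiled int8 matrix A
    # (col_turing / col_ampere layout) into a dense row-major int8 tensor.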
    shapeA = SA[0]
    formatA = SA[1]
    assert formatA in ["col_turing", "col_ampere"]
    assert A.device.type == "cuda"

    out = torch.zeros(
        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
    )

    idx_size = ct.c_int32(idx.numel())
    rows = ct.c_int32(shapeA[0])
    cols = ct.c_int32(shapeA[1])
    ptrA = get_ptr(A)
    ptrIdx = get_ptr(idx)
    ptrOut = get_ptr(out)

    prev_device = pre_call(A.device)
    if formatA == 'col_turing':
        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    elif formatA == "col_ampere":
        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    post_call(prev_device)

    return out