# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import itertools
import operator
import random
import torch
import math
from scipy.stats import norm
import numpy as np

from functools import reduce  # Required in Python 3
from typing import Tuple
from torch import Tensor

from .cextension import COMPILED_WITH_CUDA, lib


# math.prod is not available in Python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)

name2qmap = {}

if COMPILED_WITH_CUDA:
    """C FUNCTIONS FOR OPTIMIZERS"""
    str2optimizer32bit = {}
    str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16)
    str2optimizer32bit["momentum"] = (
        lib.cmomentum32bit_g32,
        lib.cmomentum32bit_g16,
    )
    str2optimizer32bit["rmsprop"] = (
        lib.crmsprop32bit_g32,
        lib.crmsprop32bit_g16,
    )
    str2optimizer32bit["adagrad"] = (
        lib.cadagrad32bit_g32,
        lib.cadagrad32bit_g16,
    )

    str2optimizer8bit = {}
    str2optimizer8bit["adam"] = (
        lib.cadam_static_8bit_g32,
        lib.cadam_static_8bit_g16,
    )
    str2optimizer8bit["momentum"] = (
        lib.cmomentum_static_8bit_g32,
        lib.cmomentum_static_8bit_g16,
    )
    str2optimizer8bit["rmsprop"] = (
        lib.crmsprop_static_8bit_g32,
        lib.crmsprop_static_8bit_g16,
    )
    str2optimizer8bit["lamb"] = (
        lib.cadam_static_8bit_g32,
        lib.cadam_static_8bit_g16,
    )
    str2optimizer8bit["lars"] = (
        lib.cmomentum_static_8bit_g32,
        lib.cmomentum_static_8bit_g16,
    )

    str2optimizer8bit_blockwise = {}
    str2optimizer8bit_blockwise["adam"] = (
        lib.cadam_8bit_blockwise_fp32,
        lib.cadam_8bit_blockwise_fp16,
        lib.cadam_8bit_blockwise_bf16,
    )
    str2optimizer8bit_blockwise["momentum"] = (
        lib.cmomentum_8bit_blockwise_fp32,
        lib.cmomentum_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["rmsprop"] = (
        lib.crmsprop_8bit_blockwise_fp32,
        lib.crmsprop_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["adagrad"] = (
        lib.cadagrad_8bit_blockwise_fp32,
        lib.cadagrad_8bit_blockwise_fp16,
    )


class CUBLAS_Context:
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = {}
        # prev_device = torch.cuda.current_device()
        # for i in range(torch.cuda.device_count()):
        #    torch.cuda.set_device(torch.device('cuda', i))
        #    self.context.append(ct.c_void_p(lib.get_context()))
        # torch.cuda.set_device(prev_device)

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def get_context(self, device):
        if device.index not in self.context:
            prev_device = torch.cuda.current_device()
            torch.cuda.set_device(device)
            self.context[device.index] = ct.c_void_p(lib.get_context())
            torch.cuda.set_device(prev_device)
        return self.context[device.index]


class Cusparse_Context:
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = ct.c_void_p(lib.get_cusparse())

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance


def create_linear_map(signed=True, total_bits=8, add_zero=True):
    sign = (-1.0 if signed else 0.0)
    total_values = 2**total_bits
    if add_zero or total_bits < 8:
        # add a zero
        # since we simulate fewer bits by having zeros in the data type,
        # we need to center the quantization around zero and as such lose
        # a single value
        total_values = (2**total_bits if not signed else 2**total_bits-1)

    values = torch.linspace(sign, 1.0, total_values)
    gap = 256 - values.numel()
    if gap == 0:
        return values
    else:
        l = values.numel()//2
        #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
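
# Usage sketch (illustrative comment, not executed): create_linear_map builds a
# 256-entry code of evenly spaced values, zero-padded in the middle when fewer
# than 256 values are generated.
#
#     code = create_linear_map(signed=True)                    # values in [-1, 1]
#     code_u4 = create_linear_map(signed=False, total_bits=4)  # simulated 4-bit, padded to 256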

def create_custom_map(seed=0, scale=0.01):
    v = [12, 10, 8, 6, 3, 2, 1]
    # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45
    # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48

    # 13B 100 steps:
    # - 4-bit evo: 86.02
    # - 4-bit norm: 78.73
    # - 4-bit FP4:
    # - 16-bit:

    # interval search on normal distribution
    #v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5
    #v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99
    #v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97
    #v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81
    #v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68
    #v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03
    #v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01
    #v = [1.5547735945968535, 1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47
    #v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85
    #v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287
    ##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293
    #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
    #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
    #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
    #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88

    # 7B evo start 
    #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905,  0.14122701] # 22.06
    #v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951]      
    #v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564]

    # 13B evo start
    #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
    #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
    v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]

    # mean evo 7B + 13B
    #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]

    # theoretically optimal (0.93333)
    #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333

    if seed > 0:
        v = np.array(v)
        np.random.seed(seed)
        v += np.random.randn(7)*scale
        print(v.tolist())
        #v[0] +=  (np.random.randn(1)*0.001)[0]
        #v[-1] +=  (np.random.randn(1)*0.001)[0]
        #print(v[0], v[-1])
        v = v.tolist()
    values = v + [0]*(256-14) +  \
             v[::-1]

    values = torch.Tensor(values)
    values[0:7] *= -1
    values = values.sort().values
    values /= values.max()
    assert values.numel() == 256
    return values

def create_normal_map(offset=0.9677083, use_extra_value=True):

    if use_extra_value:
        # one more positive value, this is an asymmetric type
        v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
        v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
        v = v1 + v2 + v3
    else:
        v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
        v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
        v = v1 + v2 + v3

    values = torch.Tensor(v)
    values = values.sort().values
    values /= values.max()
    assert values.numel() == 256
    return values
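
# Usage sketch (comment only): the normal map places its 15 (or 14) non-zero
# values at quantiles of a standard normal distribution and pads the rest of
# the 256 slots with zeros, so the result can be used like any other `code`
# tensor, e.g. with quantize_blockwise.
#
#     nf_code = create_normal_map()                         # asymmetric, 15 non-zero values
#     sym_code = create_normal_map(use_extra_value=False)   # symmetric, 14 non-zero values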

def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
    e = exponent_bits
    p = precision_bits
    has_sign = 1 if signed else 0
    assert e+p == total_bits-has_sign
    # the exponent is biased by 2**(e-1) - 1
    evalues = []
    pvalues = []
    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
        evalues.append(2**val)


    values = []
    lst = list(itertools.product([0, 1], repeat=precision_bits))
    #for ev in evalues:
    bias = 2**(exponent_bits-1)-1
    for evalue in range(2**(exponent_bits)):
        for bit_pattern in lst:
            value = (1 if evalue != 0 else 0)
            for i, pval in enumerate(list(bit_pattern)):
                value += pval*(2**-(i+1))
            if evalue == 0:
                # subnormals
                value = value*2**-(bias-1)
            else:
                # normals
                value = value*2**-(evalue-bias-2)
            values.append(value)
            if signed:
                values.append(-value)


    assert len(values) == 2**total_bits
    values.sort()
    if total_bits < 8:
        gap = 256 - len(values)
        for i in range(gap):
            values.append(0)
    values.sort()
    code = torch.Tensor(values)
    code /= code.max()

    return code
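
# Usage sketch (comment only; argument combinations must satisfy
# exponent_bits + precision_bits == total_bits - 1 for signed maps):
#
#     e4m3 = create_fp8_map(signed=True, exponent_bits=4, precision_bits=3)
#     e5m2 = create_fp8_map(signed=True, exponent_bits=5, precision_bits=2)
#
# Both return a 256-entry code normalized so that the largest magnitude is 1.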



def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    """
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and
    fraction. As the exponent increases from 0 to -7, the number
    of bits available for the fraction shrinks.

    This is a generalization of the dynamic type where a certain
    number of the bits can be reserved for the linear quantization
    region (the fraction). max_exponent_bits determines the maximum
    number of exponent bits.

    For more details see
    [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
    """

    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    non_sign_bits = total_bits - (1 if signed else 0)
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    if not signed:
        additional_items = 2 * additional_items
    for i in range(max_exponent_bits):
        fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
        if signed:
            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

        if additional_items > 0:
            boundaries = torch.linspace(0.1, 1, additional_items + 1)
            means = (boundaries[:-1] + boundaries[1:]) / 2.0
            data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
            if signed:
                data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

    data.append(0)
    data.append(1.0)

    gap = 256 - len(data)
    for i in range(gap):
        data.append(0)

    data.sort()
    return Tensor(data)
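
# Usage sketch (comment only): this is the default `code` used by
# quantize_blockwise/dequantize_blockwise when no quantization map is passed
# (cached in name2qmap["dynamic"]).
#
#     code = create_dynamic_map(signed=True)   # 256-entry dynamic 8-bit map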

def create_quantile_map(A, total_bits=8):
    q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
    q = q.tolist()
    q.append(0)

    gap = 256 - len(q)
    for i in range(gap):
        q.append(0)

    q.sort()

    q = Tensor(q)
    q = q/q.abs().max()
    return q
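
# Usage sketch (comment only). estimate_quantiles runs a CUDA kernel and needs
# at least 256 elements, so A is assumed to be a reasonably large CUDA tensor:
#
#     A = torch.randn(1024, 1024, device="cuda")
#     qmap = create_quantile_map(A)    # 256-entry code fitted to A's distribution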

def get_special_format_str():
    if not torch.cuda.is_available(): return 'col_turing'
    major, _minor = torch.cuda.get_device_capability()
    if major <= 7:
        return "col_turing"
    if major == 8:
        return "col_ampere"
    return "col_turing"


def is_on_gpu(tensors):
    on_gpu = True
    gpu_ids = set()
    for t in tensors:
        if t is None: continue # NULL pointers are fine
        on_gpu &= t.device.type == 'cuda'
        gpu_ids.add(t.device.index)
    if len(gpu_ids) > 1:
        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}')
    return on_gpu
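
# Usage sketch (comment only): call this before launching a kernel to make sure
# every tensor argument lives on the same CUDA device (None entries are ignored):
#
#     is_on_gpu([A, out, absmax])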

def get_ptr(A: Tensor) -> ct.c_void_p:
    """
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters
    ----------
    A : torch.tensor
        The PyTorch tensor.

    Returns
    -------
    ctypes.c_void_p
    """
    if A is None:
        return None
    else:
        return ct.c_void_p(A.data.data_ptr())


def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device


def post_call(prev_device):
    torch.cuda.set_device(prev_device)
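
# pre_call/post_call bracket every kernel launch below: pre_call switches the
# current CUDA device to the tensor's device and returns the previous one,
# post_call restores it. Sketch of the pattern (comment only):
#
#     prev_device = pre_call(A.device)
#     # ... launch lib.* kernel on A ...
#     post_call(prev_device)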


def get_transform_func(dtype, orderA, orderOut, transpose=False):
    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
    if not hasattr(lib, name):
        print(name)
        raise ValueError(
            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
        )
    else:
        return getattr(lib, name)


def get_transform_buffer(
    shape, dtype, device, to_order, from_order="row", transpose=False
):
    # init_func = torch.empty
    init_func = torch.zeros
    dims = len(shape)

    if dims == 2:
        rows = shape[0]
    elif dims == 3:
        rows = shape[0] * shape[1]
    cols = shape[-1]

    state = (shape, to_order)
    if transpose:
        # swap dims
        tmp = rows
        rows = cols
        cols = tmp
        state = (shape[::-1], to_order)

    if to_order == "row" or to_order == "col":
        return init_func(shape, dtype=dtype, device=device), state
    elif to_order == "col32":
        # blocks of 32 columns (padded)
        cols = 32 * ((cols + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_turing":
        # blocks of 32 columns and 8 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 8 * ((rows + 7) // 8)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_ampere":
        # blocks of 32 columns and 32 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 32 * ((rows + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    else:
        raise NotImplementedError(f"To_order not supported: {to_order}")


def nvidia_transform(
    A,
    to_order,
    from_order="row",
    out=None,
    transpose=False,
    state=None,
    ld=None,
):
    if state is None:
        state = (A.shape, from_order)
    else:
        from_order = state[1]
    if out is None:
        out, new_state = get_transform_buffer(
            state[0], A.dtype, A.device, to_order, state[1]
        )
    else:
        new_state = (state[1], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    elif ld is not None:
        n = prod(shape)
        dim1 = prod([shape[i] for i in ld])
        dim2 = ct.c_int32(n // dim1)
        dim1 = ct.c_int32(dim1)
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

    return out, new_state


def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses the SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
    and the extreme quantiles close to 0 and 1 have high variance / large estimation
    errors. These large errors can be avoided by using the offset variable which trims
    the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
    trims 1/512 = 0.2% from each side of the distribution. An offset value of 0.01 to 0.02
    usually has a much lower error but is not a minimum entropy encoding. Given an offset
    of 0.02, equidistant points in the range [0.02, 0.98] are used for the quantiles.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. Any shape.
    out : torch.Tensor
        Tensor with the 256 estimated quantiles.
    offset : float
        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
    num_quantiles : int
        The number of equally spaced quantiles.

    Returns
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.
    '''
    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
    if num_quantiles < 256 and offset == 1/(512):
        # override default arguments
        offset = 1/(2*num_quantiles)

    if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    is_on_gpu([A, out])
    device = pre_call(A.device)
    if A.dtype == torch.float32:
        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    elif A.dtype == torch.float16:
        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    else:
        raise NotImplementedError(f"Not supported data type {A.dtype}")
    post_call(device)

    if num_quantiles < 256:
        step = round(256/num_quantiles)
        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
        out = out[idx]

    return out
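
# Usage sketch (comment only; assumes a CUDA tensor with >= 256 elements):
#
#     A = torch.randn(4096, 4096, device="cuda")
#     q = estimate_quantiles(A)                      # 256 approximate quantiles
#     q64 = estimate_quantiles(A, num_quantiles=64)  # subsampled to 64 quantiles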


def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
    """
    Quantize tensor A in blocks of size `blocksize` values.

    Quantizes tensor A by dividing it into blocks of `blocksize` values
    (default 4096). Then the absolute maximum value within these blocks is
    calculated for the non-linear quantization.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map.
    absmax : torch.Tensor
        The absmax values.
    rand : torch.Tensor
        The tensor for stochastic rounding.
    out : torch.Tensor
        The output tensor (8-bit).

    Returns
    -------
    torch.Tensor:
        The 8-bit tensor.
    tuple(torch.Tensor, torch.Tensor):
        The quantization state to undo the quantization.
    """

    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    if absmax is None:
        n = A.numel()
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != 'cpu':
        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
        cblocksize = ct.c_int32(blocksize)
        prev_device = pre_call(A.device)
        code = code.to(A.device)
        if rand is not None:
            is_on_gpu([code, A, out, absmax, rand])
            assert blocksize==4096
            assert rand.numel() >= 1024
            rand_offset = random.randint(0, 1023)
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            else:
                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        else:
            is_on_gpu([code, A, out, absmax])
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
            else:
                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        post_call(A.device)
    else:
        # cpu
        code = code.cpu()
        assert rand is None
        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    if nested:
        offset = absmax.mean()
        absmax -= offset
        qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
        state = [qabsmax, code, blocksize, nested, offset, state2]
    else:
        state = [absmax, code, blocksize, nested, None, None]


    return out, state
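
# Usage sketch (comment only; assumes a CUDA build of the library). The returned
# state is [absmax, code, blocksize, nested, offset, state2], which
# dequantize_blockwise unpacks again:
#
#     A = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
#     qA, state = quantize_blockwise(A)
#     A_hat = dequantize_blockwise(qA, quant_state=state)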


def dequantize_blockwise(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
    blocksize: int = 4096,
    nested=False
) -> Tensor:
    """
    Dequantizes blockwise quantized values.

    Dequantizes the tensor A with maximum absolute values absmax in
    blocks of size 4096.

    Parameters
    ----------
    A : torch.Tensor
        The input 8-bit tensor.
    quant_state : tuple(torch.Tensor, torch.Tensor)
        Tuple of absmax and code values.
    absmax : torch.Tensor
        The absmax values.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor
        Dequantized output tensor (default: float32)


    Returns
    -------
    torch.Tensor:
        Dequantized tensor (default: float32)
    """
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    if quant_state is None:
        quant_state = (absmax, code, blocksize)
    else:
        absmax, code, blocksize, nested, offset, state2 = quant_state
        if nested:
            absmax = dequantize_blockwise(absmax, state2)
            absmax += offset

    if A.device.type != 'cpu':
        device = pre_call(A.device)
        code = code.to(A.device)
        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
            raise ValueError(f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
        is_on_gpu([A, absmax, out])
        if out.dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
        elif out.dtype == torch.float16:
            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        post_call(A.device)
    else:
        code = code.cpu()
        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    return out
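
# Sketch of the explicit-argument form (comment only, valid for the default
# nested=False case): instead of the quant_state tuple, the absmax/code/blocksize
# returned by quantize_blockwise can be passed directly:
#
#     qA, (absmax, code, blocksize, nested, _, _) = quantize_blockwise(A)
#     A_hat = dequantize_blockwise(qA, absmax=absmax, code=code, blocksize=blocksize)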

def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')

def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4')

def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
    """
    Quantize tensor A in blocks of 4-bit values.

    Quantizes tensor A by dividing it into blocks which are independently quantized to FP4 or NF4.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    absmax : torch.Tensor
        The absmax values.
    out : torch.Tensor
        The output tensor (8-bit).
    blocksize : int
        The blocksize used in quantization.
    quant_type : str
        The 4-bit quantization data type {fp4, nf4}

    Returns
    -------
    torch.Tensor:
        The 8-bit tensor with packed 4-bit values.
    tuple(torch.Tensor, torch.Size, torch.dtype, int):
        The quantization state to undo the quantization.
    """
    if A.device.type != 'cuda':
        raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
    if quant_type not in ['fp4', 'nf4']:
        raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')

    n = A.numel()
    input_shape = A.shape

    if absmax is None:
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)


    if out is None:
        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)

    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

    prev_device = pre_call(A.device)
    is_on_gpu([A, out, absmax])

    if A.dtype == torch.float32:
        if quant_type == 'fp4':
            lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
        else:
            lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
    elif A.dtype == torch.float16:
        if quant_type == 'fp4':
            lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
        else:
            lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
    post_call(A.device)

    if compress_statistics:
        offset = absmax.mean()
        absmax -= offset
        #code = create_custom_map().to(absmax.device)
        #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
        qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
        del absmax
        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
    else:
        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]

    return out, state
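
# Usage sketch (comment only; 4-bit quantization requires a CUDA tensor). Two
# values are packed per byte, so the output has (n+1)//2 bytes:
#
#     W = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
#     qW, state = quantize_nf4(W)          # or quantize_fp4(W)
#     W_hat = dequantize_nf4(qW, state)    # restored to W's dtype and shape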

def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
    return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')

def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
    return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')

def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
    """
    Dequantizes FP4 blockwise quantized values.

    Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.

    Parameters
    ----------
    A : torch.Tensor
        The input 8-bit tensor (packed 4-bit values).
    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
        Tuple of absmax values, original tensor shape and original dtype.
    absmax : torch.Tensor
        The absmax values.
    out : torch.Tensor
        Dequantized output tensor.
    blocksize : int
        The blocksize used in quantization.
    quant_type : str
        The 4-bit quantization data type {fp4, nf4}


    Returns
    -------
    torch.Tensor:
        Dequantized tensor.
    """
    if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
        raise ValueError(f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
    if quant_type not in ['fp4', 'nf4']:
        raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')

    if quant_state is None:
        assert absmax is not None and out is not None
        shape = out.shape
        dtype = out.dtype
    else:
        absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state

    if compressed_stats is not None:
        offset, state2 = compressed_stats
        absmax = dequantize_blockwise(absmax, state2)
        absmax += offset

    if out is None:
        out = torch.empty(shape, dtype=dtype, device=A.device)

    n = out.numel()


    device = pre_call(A.device)
    is_on_gpu([A, absmax, out])
    if out.dtype == torch.float32:
        if quant_type == 'fp4':
            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
        else:
            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
    elif out.dtype == torch.float16:
        if quant_type == 'fp4':
            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
        else:
            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
    post_call(A.device)

    is_transposed = (True if A.shape[0] == 1 else False)
    if is_transposed: return out.t()
    else: return out


def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    absmax = torch.abs(A).max()
    inp = A / absmax
    out = quantize_no_absmax(inp, code, out)
    return out, (absmax, code)


def dequantize(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
) -> Tensor:
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if quant_state is None:
        quant_state = (absmax, code)
    out = dequantize_no_absmax(A, quant_state[1], out)
    return out * quant_state[0]
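
# Usage sketch (comment only): quantize/dequantize scale by a single absmax for
# the whole tensor instead of per-block statistics:
#
#     A = torch.randn(1024, 1024, device="cuda")
#     qA, (absmax, code) = quantize(A)
#     A_hat = dequantize(qA, (absmax, code))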


def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Quantizes input tensor to 8-bit.

    Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
    `out` using the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor, optional
        The output tensor. Needs to be of type byte.

    Returns
    -------
    torch.Tensor:
        Quantized 8-bit tensor.
    '''
    if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
    is_on_gpu([A, out])
    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out


def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Dequantizes the 8-bit tensor to 32-bit.

    Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
    the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The 8-bit input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor
        The 32-bit output tensor.

    Returns
    -------
    torch.Tensor:
        32-bit output tensor.
    '''
    if out is None: out = torch.zeros_like(A, dtype=torch.float32)
    is_on_gpu([code, A, out])
    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out


def optimizer_update_32bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Tensor = None,
    beta2: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with one or two optimizer states.

    Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer: {adam}.
    g : torch.Tensor
        Gradient tensor.
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Optimizer state 1.
    beta1 : float
        Optimizer beta1.
    eps : float
        Optimizer epsilon.
    weight_decay : float
        Weight decay.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    state2 : torch.Tensor
        Optimizer state 2.
    beta2 : float
        Optimizer beta2.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).
    """

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())


    optim_func = None
    if g.dtype == torch.float32:
        optim_func = str2optimizer32bit[optimizer_name][0]
    elif g.dtype == torch.float16:
        optim_func = str2optimizer32bit[optimizer_name][1]
    elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
        optim_func = str2optimizer32bit[optimizer_name][2]
    else:
        raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")

    is_on_gpu([g, p, state1, state2, unorm_vec])
    prev_device = pre_call(g.device)
    optim_func(
        get_ptr(g),
        get_ptr(p),
        get_ptr(state1),
        get_ptr(state2),
        get_ptr(unorm_vec),
        ct.c_float(max_unorm),
        ct.c_float(param_norm),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_float(weight_decay),
        ct.c_int32(step),
        ct.c_float(lr),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()))
    post_call(prev_device)
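
# Usage sketch (comment only; hypothetical tensors, assumes CUDA): a single
# 32-bit Adam step over flat parameter/gradient/state buffers.
#
#     p = torch.randn(4096, device="cuda")
#     g = 0.01 * torch.randn_like(p)
#     state1 = torch.zeros_like(p)   # first moment
#     state2 = torch.zeros_like(p)   # second moment
#     optimizer_update_32bit("adam", g, p, state1, beta1=0.9, eps=1e-8,
#                            step=1, lr=1e-3, state2=state2, beta2=0.995)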


def optimizer_update_8bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    max1: Tensor,
    max2: Tensor,
    new_max1: Tensor,
    new_max2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
) -> None:
    """
    Performs an inplace Adam update.

    Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
    Uses AdamW formulation if weight decay > 0.0.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer. Choices {adam, momentum}
    g : torch.Tensor
        Gradient tensor.
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Adam state 1.
    state2 : torch.Tensor
        Adam state 2.
    beta1 : float
        Adam beta1.
    beta2 : float
        Adam beta2.
    eps : float
        Adam epsilon.
    weight_decay : float
        Weight decay.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1 : torch.Tensor
        Quantization map for first Adam state.
    qmap2 : torch.Tensor
        Quantization map for second Adam state.
    max1 : torch.Tensor
        Max value for first Adam state update.
    max2 : torch.Tensor
        Max value for second Adam state update.
    new_max1 : torch.Tensor
        Max value for the next Adam update of the first state.
    new_max2 : torch.Tensor
        Max value for the next Adam update of the second state.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.
    """

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][0](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][1](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )


def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    absmax1: Tensor,
    absmax2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:

    optim_func = None
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
    elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
          len(str2optimizer8bit_blockwise[optimizer_name])==3):
        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])

    prev_device = pre_call(g.device)
    optim_func(
Tim Dettmers's avatar
Tim Dettmers committed
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
        get_ptr(p),
        get_ptr(g),
        get_ptr(state1),
        get_ptr(state2),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
        get_ptr(qmap1),
        get_ptr(qmap2),
        get_ptr(absmax1),
        get_ptr(absmax2),
        ct.c_float(weight_decay),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()),
    )
    post_call(prev_device)

def percentile_clipping(
    grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
    """Applies percentile clipping

    grad: torch.Tensor
        The gradient tensor.
    gnorm_vec: torch.Tensor
        Vector of gradient norms. 100 elements expected.
    step: int
        The current optimization step (number of past gradient norms).
    percentile: int
        The percentile of past gradient norms to clip to (default: 5).

    """
    is_on_gpu([grad, gnorm_vec])
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )
    elif grad.dtype == torch.float16:
        lib.cpercentile_clipping_g16(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
    clip_value = torch.sqrt(vals[percentile])
    gnorm_scale = 1.0

    if current_gnorm > clip_value:
        gnorm_scale = clip_value / current_gnorm

    return current_gnorm, clip_value, gnorm_scale
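# Illustrative sketch, not part of the library API: how the returned gnorm_scale is
# typically consumed. The helper name and wiring are hypothetical; it assumes a CUDA
# gradient tensor and a persistent 100-element gnorm_vec buffer kept by the caller.
def _example_percentile_clipping_usage(grad, gnorm_vec, step):
    current_gnorm, clip_value, gnorm_scale = percentile_clipping(grad, gnorm_vec, step, percentile=5)
    if gnorm_scale < 1.0:
        # rescale the gradient in place so its norm matches the clipping threshold
        grad.mul_(gnorm_scale)
    return current_gnorm, clip_value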


def histogram_scatter_add_2d(
    histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
    assert len(histogram.shape) == 2
    assert histogram.dtype == torch.float32
    assert source.dtype == torch.float32
    assert index1.dtype == torch.int32
    assert index2.dtype == torch.int32

    assert histogram.device.type == "cuda"
    assert index1.device.type == "cuda"
    assert index2.device.type == "cuda"
    assert source.device.type == "cuda"
    maxdim1 = ct.c_int32(histogram.shape[0])
    n = ct.c_int32(index1.numel())
    is_on_gpu([histogram, index1, index2, source])
    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)

def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    if not torch.cuda.is_initialized(): torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(
            f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
        )

    sA = A.shape
    sB = B.shape
    tA = transposed_A
    tB = transposed_B

    correct = True

    if len(sA) == 2 and len(sB) == 2:
        if not tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[0] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[0] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[1] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 2:
        if not tA and not tB and A.shape[2] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 3:
        if not tA and not tB and A.shape[2] != B.shape[1]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[1]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[2]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[2]:
            correct = False

    if out is not None:
        sout = out.shape
        # special case common in backprop
        if not correct and len(sA) == 3 and len(sB) == 3:
            if (
                sout[0] == sA[2]
                and sout[1] == sB[2]
                and sA[0] == sB[0]
                and sA[1] == sB[1]
            ):
                correct = True
    else:
        if len(sA) == 2 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sB[1])
            elif tA and tB:
                sout = (sA[1], sB[0])
            elif tA and not tB:
                sout = (sA[1], sB[1])
            elif not tA and tB:
                sout = (sA[0], sB[0])
        elif len(sA) == 3 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[1])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[0])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[1])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[0])
        elif len(sA) == 3 and len(sB) == 3:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[2])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[1])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[2])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[1])
    if not correct:
        raise ValueError(
            f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
        )

    return sout
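# Illustrative sketch, not part of the library API: the shape inference performed by
# check_matmul for a plain 2D int8 matmul. The helper is hypothetical and assumes an
# initialized CUDA device.
def _example_check_matmul_shapes():
    A = torch.empty(4, 8, dtype=torch.int8, device="cuda")
    B = torch.empty(8, 16, dtype=torch.int8, device="cuda")
    # no transposes: (4, 8) @ (8, 16) -> (4, 16)
    return check_matmul(A, B, out=None, transposed_A=False, transposed_B=False)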

def cutlass3_gemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
    state=None
):
    #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
    if state is None:
        Bshape = B.shape
        bout = Bshape[1]
    else:
        Bshape = state[1]
        bout = Bshape[0]
    if out is None:
        out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[0])
    # This is a mess: cuBLAS expects column-major inputs, but PyTorch is row-major.
    # A row-major matrix is the transpose of the same memory read as column-major, so
    # instead of computing A @ B = C we compute B^T @ A^T = C^T and explicitly swap the
    # dimensions of each matrix in the arguments passed to cuBLAS.
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[0]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = n
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    #lda = ldb = ldc = 1
    #lda = 1
    if state is not None:
        m = Bshape[0]
        k = Bshape[1]
        lda = Bshape[0]
        ldc = Bshape[0]
        ldb = (ldb+1)//2
    #print(m, n, k, lda, ldb, ldc)
    is_on_gpu([B, A, out])
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)
    lda = ct.c_int32(lda)
    ldb = ct.c_int32(ldb)
    ldc = ct.c_int32(ldc)

    if B.dtype == torch.uint8:
        lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
    elif A.dtype == torch.float32:
        lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
    elif A.dtype == torch.float16:
        lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
    else:
        raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

    return out
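# Illustrative sketch, not part of the library API: the row-major vs. column-major
# bookkeeping described in the comments above, demonstrated with plain PyTorch tensors.
# A row-major product A @ B equals the transpose of the column-major style product
# B^T @ A^T, which is why the operands are handed to cuBLAS in reverse order with
# swapped dimensions.
def _example_row_major_gemm_identity():
    m, k, n = 4, 8, 3
    A = torch.randn(m, k)
    B = torch.randn(k, n)
    C = A @ B                 # row-major result, shape [m, n]
    C_t = B.t() @ A.t()       # "column-major" product, shape [n, m]
    assert torch.allclose(C, C_t.t(), atol=1e-5)
    return C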




def igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            return batched_igemm(A, B, out)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[0])
    # This is a mess: cuBLAS expects column-major inputs, but PyTorch is row-major.
    # A row-major matrix is the transpose of the same memory read as column-major, so
    # instead of computing A @ B = C we compute B^T @ A^T = C^T and explicitly swap the
    # dimensions of each matrix in the arguments passed to cuBLAS.
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    is_on_gpu([B, A, out])
    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
    return out
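# Illustrative sketch, not part of the library API: basic igemm usage, assuming a CUDA
# device. Int8 inputs accumulate into an int32 output; the float reference is exact here
# because all products fit into float32 without rounding.
def _example_igemm_usage():
    A = torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda")
    B = torch.randint(-128, 127, (8, 16), dtype=torch.int8, device="cuda")
    out = igemm(A, B)                                   # int32 tensor of shape (4, 16)
    reference = (A.float() @ B.float()).to(torch.int32)
    assert torch.equal(out, reference)
    return out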


def batched_igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    if not len(A.shape) == 3 or not len(B.shape) == 3:
        raise ValueError(
            f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
        )
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    if B.is_contiguous():
        lda = B.stride()[1]
        transposed_A = False
    else:
        s = B.stride()
        if s[0] != B.shape[0]:
            B = B.contiguous()
            lda = B.stride()[1]
        elif s[2] == B.shape[1]:
            transposed_A = True
            lda = B.stride()[2]
        else:
            if s[2] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            elif s[1] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            else:
                B = B.contiguous()
                lda = B.stride()[1]

    if A.is_contiguous():
        ldb = A.stride()[1]
        transposed_B = False
    else:
        s = A.stride()
        if s[0] != A.shape[0]:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False
        elif s[2] == A.shape[1]:
            ldb = A.stride()[2]
            transposed_B = True
        else:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False

    # This is a mess: cuBLAS expects column-major inputs, but PyTorch is row-major.
    # A row-major matrix is the transpose of the same memory read as column-major, so
    # instead of computing A @ B = C we compute B^T @ A^T = C^T and explicitly swap
    # the dimensions of each matrix in the arguments passed to cuBLAS.

    # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
    num_batch = A.shape[0]
    n = A.shape[1]
    m = B.shape[2]
    k = B.shape[1]

    ldc = m

    strideA = B.shape[1] * B.shape[2]
    strideB = A.shape[1] * A.shape[2]
    strideC = A.shape[1] * B.shape[2]

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    is_on_gpu([B, A, out])
    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
               ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
    return out

def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    shapeA = SA[0]
    shapeB = SB[0]
    dimsA = len(shapeA)
    dimsB = len(shapeB)
    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]

    rows = n = shapeB[0]
    assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'

    # if the tensor is empty, return a transformed empty tensor with the right dimensions
    if shapeA[0] == 0 and dimsA == 2:
        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
    elif shapeA[1] == 0 and dimsA == 3:
        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
        )
    elif dimsA == 3 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
        )
    assert dimsB != 3, "len(B.shape)==3 not supported"
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    assert out.dtype == dtype
    assert SA[1] == "col32"
    assert SB[1] in ["col_turing", "col_ampere"]
    assert Sout[1] == "col32"
    assert (
        shapeA[-1] == shapeB[-1]
    ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
    formatB = SB[1]
    prev_device = A.device
    torch.cuda.set_device(A.device)

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    ptrA = get_ptr(A)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)

    k = shapeA[-1]
    lda = ct.c_int32(m * 32)
    if formatB == "col_turing":
        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
    else:
        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
    ldc = ct.c_int32(m * 32)
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)

    has_error = 0
    ptrRowScale = get_ptr(None)
    is_on_gpu([A, B, out])
    if formatB == 'col_turing':
        if dtype == torch.int32:
            has_error = lib.cigemmlt_turing_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_turing_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
    elif formatB == "col_ampere":
        if dtype == torch.int32:
            has_error = lib.cigemmlt_ampere_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_ampere_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )

    if has_error == 1:
        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
        raise Exception('cublasLt ran into an error!')

    torch.cuda.set_device(prev_device)

    return out, Sout
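# Illustrative sketch, not part of the library API: the leading-dimension arithmetic used
# above. col32 pads columns to multiples of 32; col_turing tiles are 8 rows x 32 columns
# and col_ampere tiles are 32 rows x 32 columns, so ldb is the row count rounded up to
# the tile height, times 32.
def _example_igemmlt_leading_dims(m, rows, format_b="col_turing"):
    lda = m * 32                              # col32 layout of A
    if format_b == "col_turing":
        ldb = ((rows + 7) // 8) * 8 * 32      # pad rows to a multiple of 8
    else:
        ldb = ((rows + 31) // 32) * 32 * 32   # pad rows to a multiple of 32
    ldc = m * 32                              # the int32 output is col32 as well
    return lda, ldb, ldc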


def mm_dequant(
    A,
    quant_state,
    row_stats,
    col_stats,
    out=None,
    new_row_stats=None,
    new_col_stats=None,
    bias=None
):
    assert A.dtype == torch.int32
    if bias is not None: assert bias.dtype == torch.float16
    out_shape = quant_state[0]
    if len(out_shape) == 3:
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])

    if out is None:
        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
    if new_row_stats is None:
        new_row_stats = torch.empty(
            out_shape[0], dtype=torch.float32, device=A.device
        )
    if new_col_stats is None:
        new_col_stats = torch.empty(
            out_shape[1], dtype=torch.float32, device=A.device
        )
    assert (
        new_row_stats.shape[0] == row_stats.shape[0]
    ), f"{new_row_stats.shape} vs {row_stats.shape}"
    assert (
        new_col_stats.shape[0] == col_stats.shape[0]
    ), f"{new_col_stats.shape} vs {col_stats.shape}"
    prev_device = pre_call(A.device)
    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    ptrNewRowStats = get_ptr(new_row_stats)
    ptrNewColStats = get_ptr(new_col_stats)
    ptrBias = get_ptr(bias)
    numRows = ct.c_int32(out_shape[0])
    numCols = ct.c_int32(out_shape[1])

    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols)
    post_call(prev_device)

    return out


def get_colrow_absmax(
    A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
    assert A.dtype == torch.float16
    device = A.device

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    col_tiles = (cols + 255) // 256
    tiled_rows = ((rows + 15) // 16) * 16
    if row_stats is None:
        row_stats = torch.empty(
            (rows,), dtype=torch.float32, device=device
        ).fill_(-50000.0)
    if col_stats is None:
        col_stats = torch.empty(
            (cols,), dtype=torch.float32, device=device
        ).fill_(-50000.0)

    if nnz_block_ptr is None and threshold > 0.0:
        nnz_block_ptr = torch.zeros(
            ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
        )

    ptrA = get_ptr(A)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    ptrNnzrows = get_ptr(nnz_block_ptr)
    rows = ct.c_int32(rows)
    cols = ct.c_int32(cols)

    prev_device = pre_call(A.device)
    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
    post_call(prev_device)

    if threshold > 0.0:
        nnz_block_ptr.cumsum_(0)

    return row_stats, col_stats, nnz_block_ptr
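# Illustrative sketch, not part of the library API: a pure-PyTorch reference for the
# row/column absmax statistics computed above (ignoring the outlier bookkeeping), which
# can be handy for checking results on small fp16 tensors.
def _reference_colrow_absmax(A):
    A2d = A.reshape(-1, A.shape[-1]).float().abs()
    row_stats = A2d.amax(dim=1)   # per-row absolute maximum
    col_stats = A2d.amax(dim=0)   # per-column absolute maximum
    return row_stats, col_stats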

class COOSparseTensor:
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        assert rowidx.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colidx.numel() == nnz

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values
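# Illustrative sketch, not part of the library API: building a tiny COO tensor for a 2x3
# matrix with two non-zero fp16 values. The device argument is only a convention; the
# container itself does not check placement, but downstream kernels expect CUDA tensors.
def _example_coo_sparse_tensor(device="cuda"):
    rowidx = torch.tensor([0, 1], dtype=torch.int32, device=device)
    colidx = torch.tensor([2, 0], dtype=torch.int32, device=device)
    values = torch.tensor([1.5, -2.0], dtype=torch.float16, device=device)
    return COOSparseTensor(2, 3, 2, rowidx, colidx, values)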

class CSRSparseTensor:
    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        assert rowptr.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert colidx.numel() == nnz
        assert rowptr.numel() == rows + 1

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowptr = rowptr
        self.colidx = colidx
        self.values = values

class CSCSparseTensor:
    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        assert colptr.dtype == torch.int32
        assert rowidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colptr.numel() == cols + 1

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.colptr = colptr
        self.rowidx = rowidx
        self.values = values

def coo2csr(cooA):
    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    values.add_(1)
    rowptr = torch.zeros(
        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
    )
    rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
    rowptr.cumsum_(0)
    return CSRSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
    )


def coo2csc(cooA):
    val, col2rowidx = torch.sort(cooA.colidx)
    rowidx = cooA.rowidx[col2rowidx]
    values = cooA.values[col2rowidx]
    colvalues, counts = torch.unique(val, return_counts=True)
    colvalues.add_(1)
    colptr = torch.zeros(
        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
    )
    colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
    colptr.cumsum_(0)
    return CSCSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
    )
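# Illustrative sketch, not part of the library API: converting a COO tensor to CSR and
# CSC. The rowptr/colptr arrays hold cumulative non-zero counts, so consecutive entries
# delimit the non-zeros belonging to each row or column.
def _example_coo_conversions(coo):
    csr = coo2csr(coo)
    csc = coo2csc(coo)
    assert csr.rowptr.numel() == coo.rows + 1
    assert csc.colptr.numel() == coo.cols + 1
    return csr, csc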

def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    values = torch.zeros((nnz,), dtype=dtype, device=device)
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


def double_quant(
    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
    device = A.device
    assert A.dtype == torch.half
    assert device.type == "cuda"
    prev_device = pre_call(A.device)

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    if row_stats is None or col_stats is None:
        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
            A, threshold=threshold
        )
    if out_col is None:
        out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
    if out_row is None:
        out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

    coo_tensor = None
    ptrA = get_ptr(A)
    ptrColStats = get_ptr(col_stats)
    ptrRowStats = get_ptr(row_stats)
    ptrOutCol = get_ptr(out_col)
    ptrOutRow = get_ptr(out_row)

    is_on_gpu([A, col_stats, row_stats, out_col, out_row])
    if threshold > 0.0:
        nnz = nnz_row_ptr[-1].item()
        if nnz > 0:
            coo_tensor = coo_zeros(
                A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
            )
            ptrRowIdx = get_ptr(coo_tensor.rowidx)
            ptrColIdx = get_ptr(coo_tensor.colidx)
            ptrVal = get_ptr(coo_tensor.values)
            ptrRowPtr = get_ptr(nnz_row_ptr)

            lib.cdouble_rowcol_quant(
                ptrA,
                ptrRowStats,
                ptrColStats,
                ptrOutCol,
                ptrOutRow,
                ptrRowIdx,
                ptrColIdx,
                ptrVal,
                ptrRowPtr,
                ct.c_float(threshold),
                ct.c_int32(rows),
                ct.c_int32(cols),
            )
            val, idx = torch.sort(coo_tensor.rowidx)
            coo_tensor.rowidx = val
            coo_tensor.colidx = coo_tensor.colidx[idx]
            coo_tensor.values = coo_tensor.values[idx]
        else:
            lib.cdouble_rowcol_quant(
                ptrA,
                ptrRowStats,
                ptrColStats,
                ptrOutCol,
                ptrOutRow,
                None,
                None,
                None,
                None,
                ct.c_float(0.0),
                ct.c_int32(rows),
                ct.c_int32(cols),
            )
    else:
        lib.cdouble_rowcol_quant(
            ptrA,
            ptrRowStats,
            ptrColStats,
            ptrOutCol,
            ptrOutRow,
            None,
            None,
            None,
            None,
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
        )
    post_call(prev_device)

    return out_row, out_col, row_stats, col_stats, coo_tensor
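# Illustrative sketch, not part of the library API: row- and column-wise int8
# quantization of an fp16 activation matrix, assuming a CUDA device. With a positive
# threshold, values whose magnitude exceeds it are additionally returned as a COO tensor
# (the outlier path used by LLM.int8()); the 6.0 here is just an example value.
def _example_double_quant(A_fp16):
    out_row, out_col, row_stats, col_stats, coo = double_quant(A_fp16, threshold=6.0)
    # out_row/out_col are int8 tensors with the same shape as A_fp16;
    # row_stats/col_stats hold the per-row / per-column absmax used for dequantization.
    return out_row, out_col, row_stats, col_stats, coo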


def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    prev_device = pre_call(A.device)
    if state is None: state = (A.shape, from_order)
    else: from_order = state[1]
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
    else: new_state = (state[0], to_order) # (shape, order)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    is_on_gpu([A, out])
    if to_order == 'col32':
        if transpose:
            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "col_turing":
        if transpose:
            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "col_ampere":
        if transpose:
            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "row":
        if from_order == "col_turing":
            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
        elif from_order == "col_ampere":
            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
    else:
        raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')

    post_call(prev_device)

    return out, new_state
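# Illustrative sketch, not part of the library API: converting a row-major int8 tensor
# into the col32 layout expected by igemmlt, assuming a CUDA device.
def _example_transform_to_col32(A_int8):
    A_col32, state = transform(A_int8, to_order="col32", from_order="row")
    # state carries (original_shape, "col32") so later calls can track the layout
    return A_col32, state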

def spmm_coo(cooA, B, out=None):
    if out is None:
        out = torch.empty(
            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = False if B.is_contiguous() else True

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    ptr = Cusparse_Context.get_instance().context

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))

    return out

def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    if out is None:
        out = torch.zeros(
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
    transposed_B = False if B.is_contiguous() else True

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert (
        max_count[0] <= 32
    ), f"Current max count per row is 32 but found {max_count[0]}."
    assert B.dtype in [torch.float16, torch.int8]
    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ptrMaxIdx = get_ptr(max_idx)

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrDequantStats = get_ptr(dequant_stats)
    cnnz_rows = ct.c_int32(counts.numel())
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    crowsB = ct.c_int32(B.shape[1])
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    elif B.dtype == torch.int8:
        lib.cspmm_coo_very_sparse_naive_int8(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    # else: assertion error

    return out


C = 127.0


def vectorwise_quant(x, dim=1, quant_type="vector"):
    if quant_type == "linear":
        max1 = torch.abs(x).max().float()
        xq = torch.round(x / max1 * 127).to(torch.int8)
        return xq, max1
    elif quant_type in ["vector", "row"]:
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        xq = torch.round(x * (C / max1)).to(torch.int8)
        return xq, max1
    elif quant_type == "zeropoint":
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 255.0 / dyna
        minx = x.min()
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
        dtype = x.dtype
        x = x.float()
        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
            x, dim=dim, keepdim=True
        )
        dyna[dyna == 0] = 1
        qx = 255.0 / dyna
        minx = torch.amin(x, dim=dim, keepdim=True)
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type == "truncated-vector":
        with torch.no_grad():
            absx = torch.abs(x)
            max1 = torch.amax(absx, dim=dim, keepdim=True)
            max1 = max1 * 0.7
            idx = absx > max1.expand_as(absx)
            sign = torch.sign(x[idx])
            x[idx] = max1.expand_as(absx)[idx] * sign
            xq = torch.round(x / max1 * C).to(torch.int8)
        return xq, max1
    else:
        return None
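# Illustrative sketch, not part of the library API: the row-wise ("vector") quantization
# math above on a tiny CPU tensor. Each row is scaled by 127 / absmax(row) and rounded to
# int8; dequantization (defined below) divides by the same factor.
def _example_vectorwise_quant_roundtrip():
    x = torch.tensor([[0.5, -1.0, 2.0], [4.0, 0.0, -8.0]])
    xq, max1 = vectorwise_quant(x, dim=1, quant_type="vector")
    x_hat = vectorwise_dequant(xq, max1, quant_type="vector")
    # the per-element reconstruction error is bounded by absmax(row) / 127
    assert (x - x_hat).abs().max() <= max1.max() / 127
    return xq, max1, x_hat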

def vectorwise_dequant(xq, max1, quant_type="vector"):
    if quant_type == "vector":
        x = (xq / C * max1).to(torch.float32)
        return x
    else:
        return None

def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    if quant_type == "linear":
        norm = S1 * S2 / (C * C)
        # double cast needed to prevent overflows
        return (xq.float() * norm).to(dtype)
    elif quant_type == "zeropoint":
        norm = 1.0 / (S1 * S2)
        return (xq.float() * norm).to(dtype)
    elif quant_type == "row-zeropoint":
        norm = 1.0 / (S1 * S2)
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= norm
        else:
            x *= norm
        return x.to(dtype)
    elif quant_type == "vector-zeropoint":
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= 1.0 / S1
        else:
            x *= 1.0 / S1
        x *= 1.0 / S2.t()
        return x.to(dtype)
    elif quant_type == "row":
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1 * S2 / (C * C)
        else:
            x *= S1 * S2 / (C * C)
        return x.to(dtype)
    elif quant_type in ["truncated-vector", "vector"]:
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1 / C
        else:
            x *= S1 / C
        x *= S2 / C
        return x.to(dtype)
    else:
        return None


def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    offset = B.float().t().sum(0) * (SA[0] + SA[1])
    x = xq.float()
    if len(xq.shape) == 2 and len(SB.shape) == 3:
        SB = SB.squeeze(0)
    if len(SB.shape) == 2:
        x *= SB.t() / 127
    else:
        x *= SB / 127
    x *= SA[1] / 127
    x += offset
    return x.to(dtype)

def extract_outliers(A, SA, idx):
    shapeA = SA[0]
    formatA = SA[1]
    assert formatA in ["col_turing", "col_ampere"]
    assert A.device.type == "cuda"
    out = torch.zeros(
        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
    )

    idx_size = ct.c_int32(idx.numel())
    rows = ct.c_int32(shapeA[0])
    cols = ct.c_int32(shapeA[1])
    ptrA = get_ptr(A)
    ptrIdx = get_ptr(idx)
    ptrOut = get_ptr(out)

    prev_device = pre_call(A.device)
    if formatA == 'col_turing':
        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    elif formatA == "col_ampere":
        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    post_call(prev_device)

    return out

def pipeline_test(A, batch_size):
    out = torch.zeros_like(A)
    lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
    return out