ops.py 28.7 KB
Newer Older
wuxk1's avatar
wuxk1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib

lifu's avatar
lifu committed
27
28
29
import triton
import triton.language as tl
from triton.language.extra import libdevice
wuxk1's avatar
wuxk1 committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

# Optional lmslim dependency. quant_ops (hipBLASLt W8A8 GEMM) and
# per_token_quant_int8 are used by the manual_cast_int8* Linear layers below;
# lmslimquant is imported but not referenced in the code visible here.
# On failure only a message is printed: any name not imported before the
# failure stays undefined, so using an int8 op without lmslim installed raises
# NameError when that op first runs.
try:
    from lmslim import quant_ops
    import lmslimquant
    from lmslim.layers.gemm.int8_utils import per_token_quant_int8
except Exception:
    print("INFO: Please install lmslim if you want to infergptq or awq or w8a8 model")


def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    """Forward all arguments to ``torch.nn.functional.scaled_dot_product_attention``.

    This is the default implementation; it may be replaced further down the
    module by a version that pins the SDPA backend priority.
    """
    sdpa = torch.nn.functional.scaled_dot_product_attention
    return sdpa(q, k, v, *args, **kwargs)


# When CUDA is present and torch's sdpa_kernel supports `set_priority`,
# redefine scaled_dot_product_attention so every call runs under an explicit
# backend priority list (cuDNN first, the math kernel last).
try:
    if torch.cuda.is_available():
        from torch.nn.attention import SDPBackend, sdpa_kernel
        import inspect

        if "set_priority" not in inspect.signature(sdpa_kernel).parameters:
            logging.warning("Torch version too old to set sdpa backend priority.")
        else:
            SDPA_BACKEND_PRIORITY = [
                SDPBackend.CUDNN_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]

            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
                """SDPA restricted to SDPA_BACKEND_PRIORITY, tried in list order."""
                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")

# Legacy module-level alias. TODO: remove once no more references
cast_to = comfy.model_management.cast_to

def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast ``weight`` to the dtype and device of ``input`` via ``cast_to``."""
    target_dtype, target_device = input.dtype, input.device
    return comfy.model_management.cast_to(weight, target_dtype, target_device, non_blocking=non_blocking, copy=copy)

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    """Return ``(weight, bias)`` of module ``s`` cast and patched for computation.

    When ``input`` is given, a missing ``dtype``/``device`` is taken from it and
    a missing ``bias_dtype`` defaults to the resolved ``dtype``. Tensors are
    moved with ``comfy.model_management.cast_to`` (on the device's offload
    stream, if any) and then passed through ``s.bias_function`` /
    ``s.weight_function`` in order. The stream is synchronized before returning.

    Returns:
        ``(weight, bias)``; ``bias`` is ``None`` when ``s.bias`` is ``None``.
    """
    if input is not None:
        dtype = input.dtype if dtype is None else dtype
        bias_dtype = dtype if bias_dtype is None else bias_dtype
        device = input.device if device is None else device

    offload_stream = comfy.model_management.get_offload_stream(device)
    # Patch functions run inside the offload stream's context when one exists.
    patch_ctx = contextlib.nullcontext() if offload_stream is None else offload_stream
    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    def _cast_and_patch(tensor, target_dtype, patches):
        # Request a copy only when patches follow (presumably so they never
        # operate on the module's stored tensor).
        patched = len(patches) > 0
        out = comfy.model_management.cast_to(tensor, target_dtype, device, non_blocking=non_blocking, copy=patched, stream=offload_stream)
        if patched:
            with patch_ctx:
                for patch in patches:
                    out = patch(out)
        return out

    bias = None
    if s.bias is not None:
        bias = _cast_and_patch(s.bias, bias_dtype, s.bias_function)
    weight = _cast_and_patch(s.weight, dtype, s.weight_function)

    comfy.model_management.sync_stream(device, offload_stream)
    return weight, bias

class CastWeightBiasOp:
    """Mixin holding the class-level defaults read by the layers in this module.

    The ``forward`` methods of the layers below switch to
    ``forward_comfy_cast_weights`` when ``comfy_cast_weights`` is true or either
    function list is non-empty.
    """
    # NOTE: these lists are shared by every subclass/instance that does not
    # assign its own; assign a new list rather than mutating them in place.
    bias_function = []
    weight_function = []
    comfy_cast_weights = False

class disable_weight_init:
    """Namespace of torch.nn layers whose ``reset_parameters`` does nothing.

    Skipping initialization avoids wasted work when weights are overwritten
    afterwards (presumably by checkpoint loading — confirm at call sites).
    Each layer's ``forward`` uses ``forward_comfy_cast_weights`` (which casts
    weight/bias through ``cast_bias_weight``) when ``comfy_cast_weights`` is set
    or weight/bias patch functions are registered; otherwise the stock torch
    ``forward`` runs unchanged.
    """

    class Linear(torch.nn.Linear, CastWeightBiasOp):
        """``torch.nn.Linear`` without weight init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        """``torch.nn.Conv1d`` without weight init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        """``torch.nn.Conv2d`` without weight init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        """``torch.nn.Conv3d`` without weight init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        """``torch.nn.GroupNorm`` without parameter init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        """``torch.nn.LayerNorm`` without parameter init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            # Non-affine LayerNorm has weight=None; there is nothing to cast then.
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        """``comfy.rmsnorm.RMSNorm`` without parameter init; has no bias."""

        def reset_parameters(self):
            # cast_bias_weight reads s.bias, so make sure the attribute exists.
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        """``torch.nn.ConvTranspose2d`` without weight init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        """``torch.nn.ConvTranspose1d`` without weight init."""

        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        """``torch.nn.Embedding`` without weight init; accepts ``out_dtype``."""

        def reset_parameters(self):
            # cast_bias_weight reads s.bias, so make sure the attribute exists.
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            # fp16/bf16 weights are looked up in their own dtype (cast_to with
            # dtype=None); the result is converted to output_dtype afterwards.
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                # torch.nn.Embedding.forward does not accept out_dtype.
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Instantiate this namespace's ``Conv1d``/``Conv2d``/``Conv3d`` for ``dims``.

        Raises:
            ValueError: if ``dims`` is not 1, 2 or 3.
        """
        if dims == 1:
            return s.Conv1d(*args, **kwargs)
        elif dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")


class manual_cast(disable_weight_init):
    """``disable_weight_init`` layers that always take the casting path.

    With ``comfy_cast_weights = True``, each layer's ``forward`` routes through
    ``forward_comfy_cast_weights``. That method casts weight/bias via
    ``cast_bias_weight``, to the input's dtype and device for most layers.
    """

    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True



from typing import Optional
lifu's avatar
lifu committed
324
class manual_cast_int8(manual_cast):
    """``manual_cast`` variant whose ``Linear`` runs W8A8 int8 GEMMs via lmslim.

    Only ``Linear`` is overridden; every other layer is inherited from
    ``manual_cast``.
    """

    class Linear(torch.nn.Module):
        """Linear layer that quantizes its weight to int8 lazily on first forward.

        The floating-point ``weight`` is loaded as a regular parameter. On the
        first ``forward`` it is quantized per output row and cached in
        ``weight_quant``/``weight_scale``. Activations are quantized per token
        with lmslim's ``per_token_quant_int8`` and multiplied with
        ``quant_ops.hipblaslt_w8a8_gemm``.
        """

        def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype, device=device), requires_grad=False)
            if bias:
                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
            else:
                self.register_parameter("bias", None)

            # Filled lazily by forward(): int8 weight and its per-row scales.
            self.weight_quant = None
            self.weight_scale = None

        def blaslt_scaled_mm(self,
                             a: torch.Tensor,
                             b: torch.Tensor,
                             scale_a: torch.Tensor,
                             scale_b: torch.Tensor,
                             out_dtype: torch.dtype,
                             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            """Run lmslim's hipBLASLt int8 GEMM of ``a`` (m x k) with ``b`` (n x k, 'NT').

            ``scale_a``/``scale_b`` are passed as float32. Presumably the kernel
            dequantizes with them; confirm against lmslim. ``bias`` is added
            in place to the ``out_dtype`` result when given.
            """
            m = a.shape[0]
            n = b.shape[0]
            k = a.shape[1]
            _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a.to(torch.float32), scale_b.to(torch.float32), m, n, k, 'NT', out_dtype)
            if bias is not None:
                out += bias
            return out

        def weight_quant_int8(self, weight):
            """Symmetric per-row int8 quantization of a 2-D ``weight``.

            ``scale = max|row| / 127`` (with ``max|row|`` clamped to >= 1e-5) and
            ``q = clamp(round(w / scale), -128, 127)``.

            Returns:
                ``(w_q, scales)``: int8 tensor of ``weight``'s shape and float32
                scales of shape ``(out_features, 1)``.
            """
            org_w_shape = weight.shape
            w = weight.to(torch.bfloat16)
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
            qmin, qmax = -128, 127
            scales = (max_val / qmax).float()
            w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)

            assert torch.isnan(scales).sum() == 0
            assert torch.isnan(w_q).sum() == 0

            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)

            return w_q, scales

        def forward(self, input):
            """Apply the int8 linear layer to ``input`` of shape ``(..., in_features)``.

            All leading dimensions are flattened into GEMM rows and restored on
            the output, which has shape ``(..., out_features)`` and ``input``'s
            dtype.
            """
            if self.weight_quant is None:
                self.weight_quant, self.weight_scale = self.weight_quant_int8(self.weight)
                # A layer built with bias=False has no bias to re-cast.
                if self.bias is not None:
                    self.bias = torch.nn.Parameter(self.bias.to(input.dtype))

            lead_shape = input.shape[:-1]
            rows = input.reshape(-1, input.shape[-1])

            input_quant, input_scale = per_token_quant_int8(rows)
            output_tensor = self.blaslt_scaled_mm(input_quant, self.weight_quant, input_scale, self.weight_scale, input.dtype, self.bias)

            return output_tensor.reshape(*lead_shape, output_tensor.shape[-1])

lifu's avatar
lifu committed
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
# Triton kernel: smooth one row of `x` by the per-column factors `s`, then
# quantize it symmetrically to int8 with a single scale for that row.
# One program handles one row (row_id = program_id(0)); the caller launches a
# grid of (rows,). The whole row is processed with a single tl.arange(0, BLOCK)
# and no loop, so BLOCK must be >= N (columns past BLOCK are never touched).
# Column offsets are added without a column stride, i.e. the kernel assumes the
# last dimension of both x and xq is contiguous.
@triton.jit
def _per_token_quant_int8(
    x_ptr,      # input rows (any float dtype; loaded as float32)
    xq_ptr,     # output int8 rows
    s_ptr,      # per-column multipliers, N elements
    scale_ptr,  # output float32, one scale per row
    stride_x,   # element distance between consecutive rows of x
    stride_xq,  # element distance between consecutive rows of xq
    N,          # number of valid columns per row
    BLOCK: tl.constexpr,
):
    row_id = tl.program_id(0)

    cols = tl.arange(0, BLOCK)
    mask = cols < N

    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
                other=0.0).to(tl.float32)

    s = tl.load(s_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Apply smoothing factors column-wise (masked lanes stay 0).
    x = x * s

    # Row abs-max, floored at 1e-10 so an all-zero row does not divide by zero.
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    # |x * 127 / absmax| <= 127, so the rounded values fit in int8 without clamping.
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)

    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


def per_token_quant_int8_smooth(x, s):
    """Smooth ``x`` column-wise by ``s`` and quantize it to int8, one scale per row.

    Every row of ``x`` (all dimensions except the last) is multiplied
    element-wise by ``s`` and symmetrically quantized by the Triton kernel
    ``_per_token_quant_int8``.

    Args:
        x: tensor of shape ``(..., N)``.
        s: per-column multipliers with ``N`` elements (e.g. ``1 / scales``).

    Returns:
        ``(x_q, scales)``: ``x_q`` is int8 with ``x``'s shape; ``scales`` is
        float32 with shape ``x.shape[:-1] + (1,)``. They satisfy
        ``x * s ≈ x_q * scales``.
    """
    N = x.shape[-1]
    # The kernel addresses element (row, col) as row * stride + col, so give it
    # a row-major 2-D view with a unit column stride (no copy if x already is).
    x2d = x.reshape(-1, N).contiguous()
    M = x2d.shape[0]
    x_q = torch.empty((M, N), device=x.device, dtype=torch.int8)
    scales = torch.empty((M, 1), device=x.device, dtype=torch.float32)

    # Nothing to quantize for an input without rows; skip the zero-size launch.
    if M > 0:
        BLOCK = triton.next_power_of_2(N)
        # heuristics for number of warps
        num_warps = min(max(BLOCK // 256, 1), 8)

        _per_token_quant_int8[(M, )](
            x2d,
            x_q,
            s.contiguous(),
            scales,
            stride_x=x2d.stride(0),
            stride_xq=x_q.stride(0),
            N=N,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )

    return x_q.view(x.shape), scales.view(x.shape[:-1] + (1, ))



class manual_cast_int8_smooth(manual_cast):
    """``manual_cast`` variant whose ``Linear`` uses SmoothQuant-style W8A8 GEMMs.

    Each Linear carries a per-input-channel ``scales`` vector. On the first
    ``forward`` the weight is multiplied by ``scales`` and quantized to int8.
    Activations are multiplied by ``1 / scales`` inside the fused Triton kernel
    (``per_token_quant_int8_smooth``) before per-token quantization.
    ``forward_calibration`` can be used to derive ``scales`` from activations.
    """

    class Linear(torch.nn.Module):
        """Linear layer with activation smoothing and int8 weights/activations."""

        # Number of forward_calibration() calls whose activations are collected.
        CALIBRATION_STEPS = 48

        def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype, device=device), requires_grad=False)
            if bias:
                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
            else:
                self.register_parameter("bias", None)

            # Filled lazily by forward(): int8 smoothed weight, its per-row
            # scales, and the reciprocal smoothing factors for activations.
            self.weight_quant = None
            self.weight_scale = None
            self.scales_rcp = None

            # Calibration state, used only by forward_calibration()/calibration().
            self.act_scales = None
            self.count = 0
            self.alpha = 0.6

            # Per-input-channel smoothing factors. Registered as a parameter, so
            # presumably loaded with the checkpoint; calibration() overwrites it.
            self.scales = torch.nn.Parameter(torch.empty(in_features, dtype=dtype, device=device), requires_grad=False)

        def blaslt_scaled_mm(self,
                             a: torch.Tensor,
                             b: torch.Tensor,
                             scale_a: torch.Tensor,
                             scale_b: torch.Tensor,
                             out_dtype: torch.dtype,
                             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            """Run lmslim's hipBLASLt int8 GEMM of ``a`` (m x k) with ``b`` (n x k, 'NT').

            ``scale_a``/``scale_b`` are passed as float32. Presumably the kernel
            dequantizes with them; confirm against lmslim. ``bias`` is added
            in place to the ``out_dtype`` result when given.
            """
            m = a.shape[0]
            n = b.shape[0]
            k = a.shape[1]
            _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a.to(torch.float32), scale_b.to(torch.float32), m, n, k, 'NT', out_dtype)
            if bias is not None:
                out += bias
            return out

        def weight_quant_int8(self, weight):
            """Symmetric per-row int8 quantization of a 2-D ``weight``.

            Not called by ``forward`` (which uses lmslim's quantizer).

            Returns:
                ``(w_q, scales)``: int8 tensor of ``weight``'s shape and float32
                scales of shape ``(out_features, 1)``.
            """
            org_w_shape = weight.shape
            w = weight.to(torch.bfloat16)
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
            qmin, qmax = -128, 127
            scales = (max_val / qmax).float()
            w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)

            assert torch.isnan(scales).sum() == 0
            assert torch.isnan(w_q).sum() == 0

            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)

            return w_q, scales

        def per_token_quant_int8_torch(self, input):
            """Pure-PyTorch per-row int8 quantization of a 2-D ``input``.

            ``scales = max|row| / 127`` (max clamped to >= 1e-5), kept in
            ``input``'s dtype. Not called by ``forward``.

            Returns:
                ``(input_q, scales)``.
            """
            max_val = input.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
            qmin, qmax = -128, 127
            scales = max_val / qmax
            input_q = torch.clamp(torch.round(input / scales), qmin, qmax).to(torch.int8)

            assert torch.isnan(scales).sum() == 0
            assert torch.isnan(input_q).sum() == 0

            return input_q, scales

        def forward(self, input):
            """Apply the smoothed int8 linear layer to ``input`` of shape ``(..., in_features)``.

            Leading dimensions are flattened into GEMM rows and restored on the
            bfloat16 output of shape ``(..., out_features)``. To collect
            smoothing statistics instead, call ``forward_calibration``.
            """
            if self.weight_quant is None:
                # W * s and X / s keep X @ W.T unchanged while flattening activation outliers.
                weight_smooth = self.weight * self.scales
                self.scales_rcp = 1.0 / self.scales
                self.weight_quant, self.weight_scale = per_token_quant_int8(weight_smooth)
                # Free the floating-point weight once the int8 copy exists.
                del self.weight

            lead_shape = input.shape[:-1]
            rows = input.reshape(-1, input.shape[-1])

            input_quant, input_scale = per_token_quant_int8_smooth(rows, self.scales_rcp)
            output_tensor = self.blaslt_scaled_mm(input_quant, self.weight_quant, input_scale, self.weight_scale, torch.bfloat16, self.bias)

            return output_tensor.reshape(*lead_shape, output_tensor.shape[-1])

        def forward_calibration(self, input):
            """bf16 matmul forward that records statistics for the first CALIBRATION_STEPS calls."""
            lead_shape = input.shape[:-1]
            rows = input.reshape(-1, input.shape[-1])

            if self.count < self.CALIBRATION_STEPS:
                self.calibration(rows)

            output_tensor = torch.mm(rows, self.weight.to(torch.bfloat16).t())

            if self.bias is not None:
                output_tensor += self.bias.to(torch.bfloat16)

            return output_tensor.reshape(*lead_shape, output_tensor.shape[-1])

        def calibration(self, input):
            """Accumulate per-channel maxima; set ``scales`` after CALIBRATION_STEPS calls.

            Uses ``scales = max|X|^alpha / max|W|^(1 - alpha)`` per input channel,
            with both maxima clamped to be at least 1e-5.
            """
            self.count += 1
            steps = self.CALIBRATION_STEPS
            if self.count == 1:
                # Per-input-channel max |W|. The absolute value matters: a
                # signed max of an all-negative column would clamp to 1e-5.
                self.weight_max = torch.max(self.weight.to(torch.bfloat16).abs(), dim=0)[0].clamp(min=1e-5).cpu()
            if self.count <= steps:
                tensor = input.abs()
                comming_max = torch.max(tensor, dim=0)[0].cpu()
                if self.act_scales is not None:
                    self.act_scales = torch.max(self.act_scales, comming_max)
                else:
                    self.act_scales = comming_max

            if self.count == steps:
                print(f"====================================={self.count}==========================================")
                bias_dtype = self.bias.dtype if self.bias is not None else None
                print(f"weight dtype: {self.weight.dtype} bias : {bias_dtype}")
                # Place the scales where the weight lives so forward() can combine them.
                self.scales.data = (torch.pow(self.act_scales, self.alpha) / torch.pow(self.weight_max, 1 - self.alpha)).clamp(min=1e-5).to(self.weight.device)


wuxk1's avatar
wuxk1 committed
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709

def fp8_linear(self, input):
    """Compute the linear layer ``self`` with ``torch._scaled_mm`` when its weight is fp8.

    Returns ``None`` so the caller can fall back to a regular linear when the
    weight dtype is not ``torch.float8_e4m3fn``, or when ``input`` is neither
    2-D nor 3-D.

    ``self`` must provide ``weight``, ``bias``, ``scale_weight`` and
    ``scale_input``. A ``None`` scale is replaced by a float32 scalar 1.
    The output has ``input``'s dtype and leading shape with
    ``self.weight.shape[0]`` features.
    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    # Treat (tokens, features) input as (tokens, 1, features) so the 3-D path handles it.
    tensor_2d = False
    if len(input.shape) == 2:
        tensor_2d = True
        input = input.unsqueeze(1)

    input_shape = input.shape
    input_dtype = input.dtype
    if len(input.shape) == 3:
        # Weight stays fp8 (dtype=dtype); bias is cast to the activation dtype.
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
        w = w.t()

        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_weight = scale_weight.to(input.device)

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
            # 448 is presumably the largest finite float8_e4m3fn value.
            # NOTE(review): out=input clamps the caller's tensor in place (for 2-D
            # input the unsqueezed view aliases it) — confirm callers don't reuse it.
            input = torch.clamp(input, min=-448, max=448, out=input)
            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
        else:
            scale_input = scale_input.to(input.device)
            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

        if bias is not None:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
        else:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

        # Some torch versions return a tuple from _scaled_mm; keep only the output.
        if isinstance(o, tuple):
            o = o[0]

        if tensor_2d:
            return o.reshape(input_shape[0], -1)

        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

    return None

class fp8_ops(manual_cast):
    """``manual_cast`` namespace whose ``Linear`` tries ``fp8_linear`` first."""

    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            # No stored scales: fp8_linear substitutes unit scales for None.
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            try:
                fp8_out = fp8_linear(self, input)
            except Exception as e:
                # Best effort: log and fall through to the regular linear below.
                logging.info("Exception during fp8 op: {}".format(e))
            else:
                if fp8_out is not None:
                    return fp8_out

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    """Build a ``manual_cast`` namespace whose ``Linear`` stores a scalar ``scale_weight``.

    Args:
        fp8_matrix_mult: try ``fp8_linear`` before the regular casted linear.
        scale_input: keep a ``scale_input`` parameter; otherwise it is ``None``.
        override_dtype: if not ``None``, forces the ``dtype`` passed to ``Linear``.

    Returns:
        The generated ``scaled_fp8_op`` class.
    """
    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))

    def _unit_scale(device):
        # Scalar float32 parameter equal to 1, frozen.
        return torch.nn.parameter.Parameter(data=torch.ones((), device=device, dtype=torch.float32), requires_grad=False)

    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                # Keep scales that already exist; create unit ones otherwise.
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = _unit_scale(self.weight.device)

                if not scale_input:
                    self.scale_input = None
                elif not hasattr(self, 'scale_input'):
                    self.scale_input = _unit_scale(self.weight.device)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)
                scale = self.scale_weight.to(device=weight.device, dtype=weight.dtype)

                # Multiply the scalar scale into whichever operand is smaller. TODO: optimize
                if weight.numel() < input.numel():
                    return torch.nn.functional.linear(input, weight * scale, bias)
                return torch.nn.functional.linear(input * scale, weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                # Dequantize: stored weight times scale_weight.
                scale = self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                if not inplace:
                    return weight * scale
                weight *= scale
                return weight

            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                # Re-quantize: divide by the scale, then stochastically round to the stored dtype.
                scale = self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                weight = comfy.float.stochastic_rounding(weight / scale, self.weight.dtype, seed=seed)
                if not inplace_update:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
                else:
                    self.weight.data.copy_(weight)

    return scaled_fp8_op

# Optional cublas_ops package: provides a CublasLinear layer used when
# PerformanceFeature.CublasOps is enabled (see pick_operations).
try:
    from cublas_ops import CublasLinear
except ImportError:
    CUBLAS_IS_AVAILABLE = False
else:
    CUBLAS_IS_AVAILABLE = True

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        """``disable_weight_init`` namespace whose ``Linear`` is backed by CublasLinear."""

        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                # Both paths go to CublasLinear's forward (first in the MRO).
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)

lifu's avatar
lifu committed
710
711
712
713
714
715
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
    """Choose the operations namespace used to build a model's layers.

    Checked in order:
      1. ``model_config.optimizations["int8"]``: ``manual_cast_int8_smooth`` for
         ``qwen_image`` models, ``manual_cast_int8`` otherwise.
      2. ``scaled_fp8`` given: a ``scaled_fp8_ops`` namespace.
      3. fp8 compute supported, fp8 optimizations or the Fp8MatrixMultiplication
         fast flag enabled, and not ``disable_fast_fp8``: ``fp8_ops``.
      4. CublasOps fast flag, cublas available, fp16 weights and fp16/unspecified
         compute: ``cublas_ops``.
      5. No compute dtype or matching dtypes: ``disable_weight_init``.
      6. Otherwise: ``manual_cast``.
    """
    if model_config is not None and model_config.optimizations.get("int8", False):
        if model_config.unet_config.get("image_model", "") == "qwen_image":
            return manual_cast_int8_smooth
        return manual_cast_int8

    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

    if (
        fp8_compute and
        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
        not disable_fast_fp8
    ):
        return fp8_ops

    # CUBLAS_IS_AVAILABLE is checked before cublas_ops is touched, because that
    # class only exists when the optional package imported successfully.
    if (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    ):
        logging.info("Using cublas ops")
        return cublas_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast