"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f2c80b440e80226441dc6c11a95ade10defaaf11"
Commit cc4858c2 authored by justheuristic

Add a warning when this cast is first executed, to make people aware that a cast happens and that the quantization for the operation is performed in fp16.
parent 3634fc73
 import operator
+import warnings
 import torch
 import bitsandbytes.functional as F
...
@@ -229,6 +231,8 @@ class MatMul8bitLt(torch.autograd.Function):
         # Cast A to fp16
         A_dtype = A.dtype
+        if A_dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: temporarily casting input matrix from {A_dtype} to float16")
         A = A.to(torch.float16)
         # 1. Quantize A
...
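For context, below is a minimal, self-contained sketch of the behaviour the diff introduces. The helper name cast_input_to_fp16 and the standalone layout are illustrative assumptions, not the bitsandbytes API; the sketch only mirrors the warn-then-cast logic shown above.

import warnings

import torch


def cast_input_to_fp16(A: torch.Tensor) -> torch.Tensor:
    """Mirror of the diff above: warn if the input is not fp16, then cast.

    The cast itself stays unconditional (a no-op for fp16 inputs); only the
    warning is gated on the dtype check.
    """
    A_dtype = A.dtype
    if A_dtype != torch.float16:
        warnings.warn(
            f"MatMul8bitLt: temporarily casting input matrix from {A_dtype} to float16"
        )
    return A.to(torch.float16)


A = torch.randn(4, 8, dtype=torch.float32)
A_fp16 = cast_input_to_fp16(A)  # emits a UserWarning because A is fp32
assert A_fp16.dtype == torch.float16

Note that Python's default warning filter only reports a given warnings.warn call once per call site, so in practice the message appears the first time the cast path runs, which matches the intent stated in the commit message.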