Commit bf7f1d54 authored by OlivierDehaene's avatar OlivierDehaene
Browse files

fix(server): fix quantization

parent 49a6c8c1
......@@ -246,9 +246,7 @@ class BLOOMSharded(BLOOM):
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
......
......@@ -365,9 +365,7 @@ class GalacticaSharded(Galactica):
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
......
......@@ -211,9 +211,7 @@ class GPTNeoxSharded(CausalLM):
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
......
......@@ -224,10 +224,8 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state)
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None or module_name.endswith("wo"):
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment