Commit bf7f1d54 authored by OlivierDehaene
Browse files

fix(server): fix quantization

parent 49a6c8c1
...@@ -246,9 +246,7 @@ class BLOOMSharded(BLOOM): ...@@ -246,9 +246,7 @@ class BLOOMSharded(BLOOM):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq": elif quantize == "gptq":
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now"
)
elif quantize is None: elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:
......
...@@ -365,9 +365,7 @@ class GalacticaSharded(Galactica): ...@@ -365,9 +365,7 @@ class GalacticaSharded(Galactica):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq": elif quantize == "gptq":
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now"
)
elif quantize is None: elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:
......
...@@ -211,9 +211,7 @@ class GPTNeoxSharded(CausalLM): ...@@ -211,9 +211,7 @@ class GPTNeoxSharded(CausalLM):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq": elif quantize == "gptq":
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now"
)
elif quantize is None: elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:
......
...@@ -224,10 +224,8 @@ class T5Sharded(Seq2SeqLM): ...@@ -224,10 +224,8 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq" and not module_name.endswith("wo"): elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now" elif quantize is None or module_name.endswith("wo"):
)
elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:
raise ValueError(f"Unexpected quantize `{quantize}`") raise ValueError(f"Unexpected quantize `{quantize}`")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment