"vscode:/vscode.git/clone" did not exist on "e49dca6efdb32b094fa8f7db89e5943aa64f13c8"
Commit bf7f1d54 authored by OlivierDehaene's avatar OlivierDehaene
Browse files

fix(server): fix quantization

parent 49a6c8c1
......@@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM):
return linear
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
......
......@@ -364,14 +364,12 @@ class GalacticaSharded(Galactica):
return linear
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
......
......@@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM):
return linear
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
......
......@@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state)
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None or module_name.endswith("wo"):
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment