Commit 54944679 authored by zhuwenwen's avatar zhuwenwen
Browse files

update layer.py

parent 66b809cc
......@@ -316,6 +316,12 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
else:
self.use_nn_moe = False
moe_quant_params = {
"num_experts": num_experts,
"hidden_size": hidden_size,
......@@ -323,19 +329,13 @@ class FusedMoE(torch.nn.Module):
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
"use_nn_moe":self.use_nn_moe,
"use_nn_moe": self.use_nn_moe,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
else:
self.use_nn_moe = False
self.quant_method.create_weights(layer=self, **moe_quant_params)
def _load_per_tensor_weight_scale(self, shard_id: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment