Unverified Commit 9af6d22e authored by XiongfeiWei's avatar XiongfeiWei Committed by GitHub
Browse files

Use xla flag to improve the quantized model performance (#19303)


Signed-off-by: default avatarXiongfei Wei <isaacwxf23@gmail.com>
parent 4589b940
......@@ -101,7 +101,10 @@ class TPUWorker:
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
os.environ.get("LIBTPU_INIT_ARGS", "") +
" --xla_tpu_force_1d_allreduce_at_chunk_count=1")
" --xla_tpu_force_1d_allreduce_at_chunk_count=1"
" --xla_jf_conv_input_fusion=False")
# --xla_jf_conv_input_fusion=False is used to improve the perf of
# quantized matmul.
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment