Unverified Commit b7d8bd37 authored by jianan-gu's avatar jianan-gu Committed by GitHub
Browse files

Enhance IPEX integration in Trainer (#18072)



* enhance ipex import

* refine codes

* refine style

* add link

* style
Co-authored-by: default avatarStas Bekman <stas@stason.org>
parent a462fc92
...@@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case): ...@@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case):
""" """
Decorator marking a test that requires Intel Extension for PyTorch. Decorator marking a test that requires Intel Extension for PyTorch.
These tests are skipped when Intel Extension for PyTorch isn't installed. These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
version.
""" """
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case) return unittest.skipUnless(
is_ipex_available(),
"test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
" https://github.com/intel/intel-extension-for-pytorch",
)(test_case)
def require_torch_scatter(test_case): def require_torch_scatter(test_case):
......
...@@ -1211,8 +1211,8 @@ class Trainer: ...@@ -1211,8 +1211,8 @@ class Trainer:
def ipex_optimize_model(self, model, training=False, dtype=torch.float32): def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
if not is_ipex_available(): if not is_ipex_available():
raise ImportError( raise ImportError(
"Using IPEX but IPEX is not installed, please refer to" "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
" https://github.com/intel/intel-extension-for-pytorch." " to https://github.com/intel/intel-extension-for-pytorch."
) )
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
...@@ -1223,7 +1223,9 @@ class Trainer: ...@@ -1223,7 +1223,9 @@ class Trainer:
else: else:
if not model.training: if not model.training:
model.train() model.train()
model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1") model, self.optimizer = ipex.optimize(
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
)
return model return model
......
...@@ -443,7 +443,25 @@ def is_apex_available(): ...@@ -443,7 +443,25 @@ def is_apex_available():
def is_ipex_available(): def is_ipex_available():
return importlib.util.find_spec("intel_extension_for_pytorch") is not None def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
_ipex_version = "N/A"
try:
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
except importlib_metadata.PackageNotFoundError:
return False
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
if torch_major_and_minor != ipex_major_and_minor:
logger.warning(
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
)
return False
return True
def is_bitsandbytes_available(): def is_bitsandbytes_available():
......
...@@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, 10) self.assertEqual(train_output.global_step, 10)
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu @require_torch_bf16_cpu
@require_intel_extension_for_pytorch @require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self): def test_number_of_steps_in_training_with_ipex(self):
...@@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"] expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc) self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu @require_torch_bf16_cpu
@require_intel_extension_for_pytorch @require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self): def test_evaluate_with_ipex(self):
...@@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu @require_torch_bf16_cpu
@require_intel_extension_for_pytorch @require_intel_extension_for_pytorch
def test_predict_with_ipex(self): def test_predict_with_ipex(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment