Unverified commit b7d8bd37 authored by jianan-gu, committed by GitHub

Enhance IPEX integration in Trainer (#18072)



* enhance ipex import

* refine code

* refine style

* add link

* style
Co-authored-by: Stas Bekman <stas@stason.org>
parent a462fc92
@@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case):
     """
     Decorator marking a test that requires Intel Extension for PyTorch.
-    These tests are skipped when Intel Extension for PyTorch isn't installed.
+    These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
+    version.
     """
-    return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
+    return unittest.skipUnless(
+        is_ipex_available(),
+        "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
+        " https://github.com/intel/intel-extension-for-pytorch",
+    )(test_case)
 
 
 def require_torch_scatter(test_case):
......
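A minimal usage sketch of the updated decorator, assuming it is exported from transformers.testing_utils as in the hunk above; the test class name and assertion are illustrative only:

```python
import unittest

from transformers.testing_utils import require_intel_extension_for_pytorch


class ExampleIpexTest(unittest.TestCase):
    # The test is skipped automatically when IPEX is missing or its
    # major.minor version does not match the installed PyTorch.
    @require_intel_extension_for_pytorch
    def test_runs_only_with_matching_ipex(self):
        import intel_extension_for_pytorch  # noqa: F401  # safe: the decorator guarantees availability

        self.assertTrue(True)
```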
@@ -1211,8 +1211,8 @@ class Trainer:
     def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
         if not is_ipex_available():
             raise ImportError(
-                "Using IPEX but IPEX is not installed, please refer to"
-                " https://github.com/intel/intel-extension-for-pytorch."
+                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
+                " to https://github.com/intel/intel-extension-for-pytorch."
             )
 
         import intel_extension_for_pytorch as ipex
@@ -1223,7 +1223,9 @@ class Trainer:
         else:
             if not model.training:
                 model.train()
-            model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
+            model, self.optimizer = ipex.optimize(
+                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
+            )
 
         return model
......
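A minimal sketch of the ipex.optimize call pattern the training branch moves to, assuming intel_extension_for_pytorch is installed and matches the local PyTorch version; the toy model and optimizer are illustrative only:

```python
import torch
import intel_extension_for_pytorch as ipex

# Toy model and optimizer purely for illustration.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

model.train()
# inplace=True asks IPEX to optimize the existing module rather than returning
# a converted copy, mirroring the training branch in ipex_optimize_model above.
model, optimizer = ipex.optimize(model, dtype=torch.float32, optimizer=optimizer, inplace=True, level="O1")

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
```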
@@ -443,7 +443,25 @@ def is_apex_available():
 def is_ipex_available():
-    return importlib.util.find_spec("intel_extension_for_pytorch") is not None
+    def get_major_and_minor_from_version(full_version):
+        return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
+
+    if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
+        return False
+    _ipex_version = "N/A"
+    try:
+        _ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
+    except importlib_metadata.PackageNotFoundError:
+        return False
+    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
+    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
+    if torch_major_and_minor != ipex_major_and_minor:
+        logger.warning(
+            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
+            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
+        )
+        return False
+    return True
 
 
 def is_bitsandbytes_available():
......
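A standalone sketch of the same major.minor gate, assuming packaging and the standard importlib.metadata module are available; the printed messages are illustrative only:

```python
import importlib.metadata as importlib_metadata

import torch
from packaging import version


def major_minor(full_version: str) -> str:
    # Reduce a version string such as "1.12.1+cpu" to "1.12".
    parsed = version.parse(full_version)
    return f"{parsed.major}.{parsed.minor}"


try:
    ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
except importlib_metadata.PackageNotFoundError:
    ipex_version = None

if ipex_version is None:
    print("IPEX not installed")
elif major_minor(ipex_version) != major_minor(torch.__version__):
    # e.g. IPEX 1.12 paired with PyTorch 1.13 is rejected, as in is_ipex_available()
    print(f"IPEX {ipex_version} does not match PyTorch {torch.__version__}")
else:
    print("IPEX and PyTorch versions match")
```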
@@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, 10)
 
-    @unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
     @require_torch_bf16_cpu
     @require_intel_extension_for_pytorch
     def test_number_of_steps_in_training_with_ipex(self):
@@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
         self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
 
-    @unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
     @require_torch_bf16_cpu
     @require_intel_extension_for_pytorch
     def test_evaluate_with_ipex(self):
@@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
         self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
 
-    @unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
     @require_torch_bf16_cpu
     @require_intel_extension_for_pytorch
     def test_predict_with_ipex(self):
......
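A rough sketch of how the re-enabled tests drive IPEX through the Trainer, assuming the use_ipex, bf16 and no_cuda flags on TrainingArguments; the output directory and step count are illustrative only:

```python
from transformers import TrainingArguments

# Enable IPEX optimizations plus bf16 on CPU, roughly as the unskipped tests do.
args = TrainingArguments(
    output_dir="tmp_ipex_run",  # hypothetical output directory
    use_ipex=True,
    bf16=True,
    no_cuda=True,
    max_steps=10,
)
# A Trainer built with these arguments would run ipex_optimize_model() on its
# model before training, provided is_ipex_available() returns True.
```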