"template/testdata/vscode:/vscode.git/clone" did not exist on "93a8daf285af45ed71544e79aae0cb15245e75f4"
Unverified commit ee92b58b, authored by Antoni Baum, committed by GitHub
Browse files

Move bfloat16 check to worker (#1259)

parent 09ff7f10
......@@ -345,15 +345,6 @@ def _get_and_verify_dtype(
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype
......
......@@ -59,6 +59,8 @@ class Worker:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment.
_init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
......@@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
f"(required shared memory {required_shared_mem} > "
f"available shared memory {max_shared_mem}). "
"This will be fixed in a future release.")
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment