Commit ee3d6944 authored by xuxzh1's avatar xuxzh1 🎱
Browse files

Perfect the adaptation for v3.0.0

parent 7aad7450
......@@ -1562,7 +1562,8 @@ class FlashCausalLM(Model):
num_tokens = batch.to_pb().current_tokens
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False)
#torch.cuda.tunable.tuning_enable(False)
pass
synchronize(self.device)
free_memory = get_free_memory(
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
......@@ -1619,10 +1620,11 @@ class FlashCausalLM(Model):
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
):
torch.cuda.tunable.enable()
#torch.cuda.tunable.enable()
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True)
#torch.cuda.tunable.tuning_enable(True)
pass
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
tuning_sequences = [
......@@ -1644,25 +1646,25 @@ class FlashCausalLM(Model):
f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
)
torch.cuda.tunable.set_filename(
tunableop_filepath, insert_device_ordinal=False
)
# torch.cuda.tunable.set_filename(
# tunableop_filepath, insert_device_ordinal=False
# )
if os.path.isfile(tunableop_filepath):
log_master(
logger.info,
f"The file {tunableop_filepath} already exists and will be reused.",
)
torch.cuda.tunable.read_file(tunableop_filepath)
# if os.path.isfile(tunableop_filepath):
# log_master(
# logger.info,
# f"The file {tunableop_filepath} already exists and will be reused.",
# )
# torch.cuda.tunable.read_file(tunableop_filepath)
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
for seqlen in tuning_sequences:
log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
torch.cuda.tunable.tuning_enable(False)
# for seqlen in tuning_sequences:
# log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
# self.tunableop_warmup(seqlen)
# torch.cuda.tunable.write_file(tunableop_filepath)
# if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
# torch.cuda.tunable.tuning_enable(False)
else:
log_master(
logger.info,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment