Commit ee3d6944 authored by xuxzh1's avatar xuxzh1 🎱
Browse files

Perfect the adaptation for v3.0.0

parent 7aad7450
...@@ -1562,7 +1562,8 @@ class FlashCausalLM(Model): ...@@ -1562,7 +1562,8 @@ class FlashCausalLM(Model):
num_tokens = batch.to_pb().current_tokens num_tokens = batch.to_pb().current_tokens
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False) #torch.cuda.tunable.tuning_enable(False)
pass
synchronize(self.device) synchronize(self.device)
free_memory = get_free_memory( free_memory = get_free_memory(
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
...@@ -1619,10 +1620,11 @@ class FlashCausalLM(Model): ...@@ -1619,10 +1620,11 @@ class FlashCausalLM(Model):
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
): ):
torch.cuda.tunable.enable() #torch.cuda.tunable.enable()
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True) #torch.cuda.tunable.tuning_enable(True)
pass
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None: if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
tuning_sequences = [ tuning_sequences = [
...@@ -1644,25 +1646,25 @@ class FlashCausalLM(Model): ...@@ -1644,25 +1646,25 @@ class FlashCausalLM(Model):
f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
) )
torch.cuda.tunable.set_filename( # torch.cuda.tunable.set_filename(
tunableop_filepath, insert_device_ordinal=False # tunableop_filepath, insert_device_ordinal=False
) # )
if os.path.isfile(tunableop_filepath): # if os.path.isfile(tunableop_filepath):
log_master( # log_master(
logger.info, # logger.info,
f"The file {tunableop_filepath} already exists and will be reused.", # f"The file {tunableop_filepath} already exists and will be reused.",
) # )
torch.cuda.tunable.read_file(tunableop_filepath) # torch.cuda.tunable.read_file(tunableop_filepath)
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
for seqlen in tuning_sequences: # for seqlen in tuning_sequences:
log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") # log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen) # self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath) # torch.cuda.tunable.write_file(tunableop_filepath)
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": # if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
torch.cuda.tunable.tuning_enable(False) # torch.cuda.tunable.tuning_enable(False)
else: else:
log_master( log_master(
logger.info, logger.info,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment