Unverified Commit 8707dc2c authored by Karen Chung's avatar Karen Chung Committed by GitHub
Browse files

fix: scale synthesized data length correctly for expected cache hit stats (#6117)

parent d8628cc4
...@@ -43,6 +43,7 @@ class Synthesizer: ...@@ -43,6 +43,7 @@ class Synthesizer:
prefix_root_multiplier: int = 1, prefix_root_multiplier: int = 1,
prefix_len_multiplier: float = 1.0, prefix_len_multiplier: float = 1.0,
prompt_len_multiplier: float = 1.0, prompt_len_multiplier: float = 1.0,
osl_multiplier: float = 1.0,
): ):
"""Load the mooncake dataset and extract core statistics like """Load the mooncake dataset and extract core statistics like
radix-tree structure, ISL, OSL, and request timings. radix-tree structure, ISL, OSL, and request timings.
...@@ -68,6 +69,8 @@ class Synthesizer: ...@@ -68,6 +69,8 @@ class Synthesizer:
prompt_len_multiplier (float, optional): Multiplies the leaf path lengths by this factor prompt_len_multiplier (float, optional): Multiplies the leaf path lengths by this factor
(rounded to integers). Use values < 1 to generate shorter prompts. Defaults to 1. (rounded to integers). Use values < 1 to generate shorter prompts. Defaults to 1.
Note this does not affect the lengths of the core context prompts. Note this does not affect the lengths of the core context prompts.
osl_multiplier (float, optional): Multiplies output sequence lengths by this factor.
Defaults to 1.
NOTE: currently may only work for the mooncake trace file, NOTE: currently may only work for the mooncake trace file,
as it assumes consecutive integers as it assumes consecutive integers
...@@ -81,6 +84,7 @@ class Synthesizer: ...@@ -81,6 +84,7 @@ class Synthesizer:
self.speedup_ratio = float(speedup_ratio) self.speedup_ratio = float(speedup_ratio)
self.prefix_len_multiplier = float(prefix_len_multiplier) self.prefix_len_multiplier = float(prefix_len_multiplier)
self.prompt_len_multiplier = float(prompt_len_multiplier) self.prompt_len_multiplier = float(prompt_len_multiplier)
self.osl_multiplier = float(osl_multiplier)
# assert correct arg bounds # assert correct arg bounds
assert ( assert (
...@@ -183,6 +187,13 @@ class Synthesizer: ...@@ -183,6 +187,13 @@ class Synthesizer:
if self.prefix_len_multiplier > 1: if self.prefix_len_multiplier > 1:
multiplier = int(np.ceil(self.prefix_len_multiplier)) multiplier = int(np.ceil(self.prefix_len_multiplier))
# Scale length attributes BEFORE relabeling
for node in self.G.nodes():
if node >= 0: # Skip special nodes
self.G.nodes[node]["length"] = (
self.G.nodes[node]["length"] * multiplier
)
# Create mapping for relabeling, preserving -1 and -2 # Create mapping for relabeling, preserving -1 and -2
mapping = { mapping = {
node: (node if node < 0 else node * multiplier + multiplier) node: (node if node < 0 else node * multiplier + multiplier)
...@@ -283,7 +294,9 @@ class Synthesizer: ...@@ -283,7 +294,9 @@ class Synthesizer:
) * self.block_size + self.input_lens_mod_sampler.sample() ) * self.block_size + self.input_lens_mod_sampler.sample()
else: else:
input_len = len(path) * self.block_size input_len = len(path) * self.block_size
output_len = self.output_lens_sampler.sample() output_len = int(
self.output_lens_sampler.sample() * self.osl_multiplier
)
# Apply filtering for ISL # Apply filtering for ISL
if max_isl is not None and input_len > max_isl: if max_isl is not None and input_len > max_isl:
...@@ -397,6 +410,12 @@ def main(): ...@@ -397,6 +410,12 @@ def main():
default=1.0, default=1.0,
help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)", help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
) )
parser.add_argument(
"--osl-multiplier",
type=float,
default=1.0,
help="Multiplier for output sequence lengths (default: 1.0)",
)
parser.add_argument( parser.add_argument(
"--max-isl", "--max-isl",
type=int, type=int,
...@@ -436,6 +455,7 @@ def main(): ...@@ -436,6 +455,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
dataset_file = Path(args.input_file).resolve() dataset_file = Path(args.input_file).resolve()
if args.output_file is None: if args.output_file is None:
suffix_parts = [ suffix_parts = [
f"{dataset_file.stem}_synth", f"{dataset_file.stem}_synth",
...@@ -450,6 +470,8 @@ def main(): ...@@ -450,6 +470,8 @@ def main():
suffix_parts.append(f"minosl{args.min_osl}") suffix_parts.append(f"minosl{args.min_osl}")
if args.max_osl is not None: if args.max_osl is not None:
suffix_parts.append(f"maxosl{args.max_osl}") suffix_parts.append(f"maxosl{args.max_osl}")
if args.osl_multiplier != 1.0:
suffix_parts.append(f"oslx{args.osl_multiplier:.1f}")
output_file = dataset_file.with_stem("_".join(suffix_parts)) output_file = dataset_file.with_stem("_".join(suffix_parts))
else: else:
output_file = Path(args.output_file).resolve() output_file = Path(args.output_file).resolve()
...@@ -462,6 +484,7 @@ def main(): ...@@ -462,6 +484,7 @@ def main():
prefix_len_multiplier=args.prefix_len_multiplier, prefix_len_multiplier=args.prefix_len_multiplier,
prefix_root_multiplier=args.prefix_root_multiplier, prefix_root_multiplier=args.prefix_root_multiplier,
prompt_len_multiplier=args.prompt_len_multiplier, prompt_len_multiplier=args.prompt_len_multiplier,
osl_multiplier=args.osl_multiplier,
) )
print("synthesizing requests...", flush=True) print("synthesizing requests...", flush=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment