Unverified commit 8707dc2c, authored by Karen Chung, committed by GitHub
Browse files

fix: scale synthesized data length correctly for expected cache hit stats (#6117)

parent d8628cc4
......@@ -43,6 +43,7 @@ class Synthesizer:
prefix_root_multiplier: int = 1,
prefix_len_multiplier: float = 1.0,
prompt_len_multiplier: float = 1.0,
osl_multiplier: float = 1.0,
):
"""Load the mooncake dataset and extract core statistics like
radix-tree structure, ISL, OSL, and request timings.
......@@ -68,6 +69,8 @@ class Synthesizer:
prompt_len_multiplier (float, optional): Multiplies the leaf path lengths by this factor
(rounded to integers). Use values < 1 to generate shorter prompts. Defaults to 1.
Note this does not affect the lengths of the core context prompts.
osl_multiplier (float, optional): Multiplies output sequence lengths by this factor.
Defaults to 1.
NOTE: currently may only work for the mooncake trace file,
as it assumes consecutive integers
......@@ -81,6 +84,7 @@ class Synthesizer:
self.speedup_ratio = float(speedup_ratio)
self.prefix_len_multiplier = float(prefix_len_multiplier)
self.prompt_len_multiplier = float(prompt_len_multiplier)
self.osl_multiplier = float(osl_multiplier)
# assert correct arg bounds
assert (
......@@ -183,6 +187,13 @@ class Synthesizer:
if self.prefix_len_multiplier > 1:
multiplier = int(np.ceil(self.prefix_len_multiplier))
# Scale length attributes BEFORE relabeling
for node in self.G.nodes():
if node >= 0: # Skip special nodes
self.G.nodes[node]["length"] = (
self.G.nodes[node]["length"] * multiplier
)
# Create mapping for relabeling, preserving -1 and -2
mapping = {
node: (node if node < 0 else node * multiplier + multiplier)
......@@ -283,7 +294,9 @@ class Synthesizer:
) * self.block_size + self.input_lens_mod_sampler.sample()
else:
input_len = len(path) * self.block_size
output_len = self.output_lens_sampler.sample()
output_len = int(
self.output_lens_sampler.sample() * self.osl_multiplier
)
# Apply filtering for ISL
if max_isl is not None and input_len > max_isl:
......@@ -397,6 +410,12 @@ def main():
default=1.0,
help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
)
parser.add_argument(
"--osl-multiplier",
type=float,
default=1.0,
help="Multiplier for output sequence lengths (default: 1.0)",
)
parser.add_argument(
"--max-isl",
type=int,
......@@ -436,6 +455,7 @@ def main():
args = parser.parse_args()
dataset_file = Path(args.input_file).resolve()
if args.output_file is None:
suffix_parts = [
f"{dataset_file.stem}_synth",
......@@ -450,6 +470,8 @@ def main():
suffix_parts.append(f"minosl{args.min_osl}")
if args.max_osl is not None:
suffix_parts.append(f"maxosl{args.max_osl}")
if args.osl_multiplier != 1.0:
suffix_parts.append(f"oslx{args.osl_multiplier:.1f}")
output_file = dataset_file.with_stem("_".join(suffix_parts))
else:
output_file = Path(args.output_file).resolve()
......@@ -462,6 +484,7 @@ def main():
prefix_len_multiplier=args.prefix_len_multiplier,
prefix_root_multiplier=args.prefix_root_multiplier,
prompt_len_multiplier=args.prompt_len_multiplier,
osl_multiplier=args.osl_multiplier,
)
print("synthesizing requests...", flush=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment