Unverified commit cf632499, authored by Yanan Cao, committed by GitHub
Browse files

[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)


Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent a3774a81
...@@ -160,10 +160,11 @@ class TestConfigManager: ...@@ -160,10 +160,11 @@ class TestConfigManager:
"""Test getting config file path for a kernel.""" """Test getting config file path for a kernel."""
manager = ConfigManager(base_dir="/tmp") manager = ConfigManager(base_dir="/tmp")
file_path = manager.get_config_file_path("silu_mul_fp8") dir_path = manager.get_config_file_path("silu_mul_fp8")
assert dir_path == Path("/tmp/silu_mul_fp8")
expected_path = Path("/tmp/silu_mul_fp8.json") file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
assert file_path == expected_path assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
def test_ensure_base_dir_exists(self): def test_ensure_base_dir_exists(self):
"""Test ensuring base directory exists.""" """Test ensuring base directory exists."""
...@@ -189,19 +190,19 @@ class TestConfigManager: ...@@ -189,19 +190,19 @@ class TestConfigManager:
assert config_set.get_platforms() == [] assert config_set.get_platforms() == []
def test_load_config_set_valid_file(self): def test_load_config_set_valid_file(self):
"""Test loading config set from valid file.""" """Test loading config set from per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = { kernel_config = {
"block_sizes": [128, 64], "block_sizes": [128, 64],
"num_warps": 8, "num_warps": 8,
"num_stages": 6, "num_stages": 6,
"pid_type": "persistent_interleaved", "pid_type": "persistent_interleaved",
} }
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}} kernel_dir = Path(temp_dir) / "test_kernel"
config_file = Path(temp_dir) / "test_kernel.json" kernel_dir.mkdir()
with open(config_file, "w") as f: platform_file = kernel_dir / "h100.json"
json.dump(config_data, f) with open(platform_file, "w") as f:
json.dump({"batch_32_hidden_4096": kernel_config}, f)
manager = ConfigManager(base_dir=temp_dir) manager = ConfigManager(base_dir=temp_dir)
config_set = manager.load_config_set("test_kernel") config_set = manager.load_config_set("test_kernel")
...@@ -210,7 +211,6 @@ class TestConfigManager: ...@@ -210,7 +211,6 @@ class TestConfigManager:
assert config_set.kernel_name == "test_kernel" assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"] assert config_set.get_platforms() == ["h100"]
# Verify the config was loaded correctly
config = config_set.get_config("h100", "batch_32_hidden_4096") config = config_set.get_config("h100", "batch_32_hidden_4096")
assert isinstance(config, helion.Config) assert isinstance(config, helion.Config)
assert config.block_sizes == [128, 64] assert config.block_sizes == [128, 64]
...@@ -219,7 +219,9 @@ class TestConfigManager: ...@@ -219,7 +219,9 @@ class TestConfigManager:
def test_load_config_set_invalid_json(self): def test_load_config_set_invalid_json(self):
"""Test loading config set from file with invalid JSON.""" """Test loading config set from file with invalid JSON."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
config_file = Path(temp_dir) / "test_kernel.json" kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
config_file = kernel_dir / "h100.json"
with open(config_file, "w") as f: with open(config_file, "w") as f:
f.write("invalid json content {") f.write("invalid json content {")
...@@ -231,9 +233,8 @@ class TestConfigManager: ...@@ -231,9 +233,8 @@ class TestConfigManager:
assert config_set.get_platforms() == [] assert config_set.get_platforms() == []
def test_save_config_set(self): def test_save_config_set(self):
"""Test saving ConfigSet to file.""" """Test saving ConfigSet to per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = { kernel_config = {
"block_sizes": [256, 128], "block_sizes": [256, 128],
"num_warps": 16, "num_warps": 16,
...@@ -246,31 +247,34 @@ class TestConfigManager: ...@@ -246,31 +247,34 @@ class TestConfigManager:
manager = ConfigManager(base_dir=temp_dir) manager = ConfigManager(base_dir=temp_dir)
saved_path = manager.save_config_set(config_set) saved_path = manager.save_config_set(config_set)
expected_path = Path(temp_dir) / "test_kernel.json" expected_dir = Path(temp_dir) / "test_kernel"
assert saved_path == expected_path assert saved_path == expected_dir
assert saved_path.exists() assert saved_path.is_dir()
with open(saved_path) as f: platform_file = expected_dir / "h100.json"
assert platform_file.exists()
with open(platform_file) as f:
loaded_data = json.load(f) loaded_data = json.load(f)
assert loaded_data == data assert loaded_data == data["h100"]
def test_save_config_set_creates_directory(self): def test_save_config_set_creates_directory(self):
"""Test that save_config_set creates parent directories if needed.""" """Test that save_config_set creates parent directories if needed."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
nested_dir = Path(temp_dir) / "nested" / "configs" nested_dir = Path(temp_dir) / "nested" / "configs"
config_set = ConfigSet("test_kernel") data = {"h100": {"default": {"num_warps": 4}}}
config_set = ConfigSet.from_dict("test_kernel", data)
manager = ConfigManager(base_dir=nested_dir) manager = ConfigManager(base_dir=nested_dir)
saved_path = manager.save_config_set(config_set) saved_path = manager.save_config_set(config_set)
assert nested_dir.exists() assert nested_dir.exists()
assert nested_dir.is_dir() assert nested_dir.is_dir()
assert saved_path.exists() assert saved_path.is_dir()
assert (saved_path / "h100.json").exists()
def test_get_platform_configs(self): def test_get_platform_configs(self):
"""Test getting all configs for a specific platform.""" """Test getting all configs for a specific platform."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]} config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]} config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
default_config = { default_config = {
...@@ -280,17 +284,19 @@ class TestConfigManager: ...@@ -280,17 +284,19 @@ class TestConfigManager:
} }
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]} config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
config_data = { kernel_dir = Path(temp_dir) / "test_kernel"
"h100": { kernel_dir.mkdir()
"batch_32_hidden_4096": config_1, with open(kernel_dir / "h100.json", "w") as f:
"batch_64_hidden_2048": config_2, json.dump(
"default": default_config, {
}, "batch_32_hidden_4096": config_1,
"a100": {"batch_16_hidden_1024": config_3}, "batch_64_hidden_2048": config_2,
} "default": default_config,
config_file = Path(temp_dir) / "test_kernel.json" },
with open(config_file, "w") as f: f,
json.dump(config_data, f) )
with open(kernel_dir / "a100.json", "w") as f:
json.dump({"batch_16_hidden_1024": config_3}, f)
manager = ConfigManager(base_dir=temp_dir) manager = ConfigManager(base_dir=temp_dir)
...@@ -302,7 +308,6 @@ class TestConfigManager: ...@@ -302,7 +308,6 @@ class TestConfigManager:
for config in h100_configs.values(): for config in h100_configs.values():
assert isinstance(config, helion.Config) assert isinstance(config, helion.Config)
# Verify specific config details
assert h100_configs["batch_32_hidden_4096"].num_warps == 4 assert h100_configs["batch_32_hidden_4096"].num_warps == 4
assert h100_configs["default"].num_stages == 7 assert h100_configs["default"].num_stages == 7
......
...@@ -8,23 +8,15 @@ operations, including naming conventions, directory resolution, and file I/O. ...@@ -8,23 +8,15 @@ operations, including naming conventions, directory resolution, and file I/O.
Config File Structure Config File Structure
--------------------- ---------------------
Each kernel has a single JSON config file: {kernel_name}.json Each kernel has a directory: {kernel_name}/
Inside, each GPU platform has its own JSON file: {kernel_name}/{platform}.json
The file uses a simplified 2-layer hierarchical structure:
{
"h100": { # GPU platform
"default": { ... }, # Fallback configuration
"batch_32_hidden_4096": { ... },
"batch_64_hidden_8192": { ... }
},
"a100": {
"default": { ... },
"batch_16_hidden_2048": { ... }
}
}
Example file: silu_mul_fp8.json
For example:
silu_mul_fp8/
nvidia_h100.json # { "default": {...}, "batch_32_hidden_4096": {...} }
nvidia_h200.json # { "batch_16_hidden_2048": {...} }
Each platform file maps config keys to Helion config objects.
Config keys should be structured strings that encode the relevant Config keys should be structured strings that encode the relevant
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.). parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.).
...@@ -212,8 +204,15 @@ class ConfigManager: ...@@ -212,8 +204,15 @@ class ConfigManager:
cls._instance = None cls._instance = None
cls._instance_base_dir = None cls._instance_base_dir = None
def get_config_file_path(self, kernel_name: str) -> Path: def get_kernel_dir(self, kernel_name: str) -> Path:
return self._base_dir / f"{kernel_name}.json" return self._base_dir / kernel_name
def get_config_file_path(
self, kernel_name: str, platform: str | None = None
) -> Path:
if platform is not None:
return self.get_kernel_dir(kernel_name) / f"{platform}.json"
return self.get_kernel_dir(kernel_name)
def ensure_base_dir_exists(self) -> Path: def ensure_base_dir_exists(self) -> Path:
self._base_dir.mkdir(parents=True, exist_ok=True) self._base_dir.mkdir(parents=True, exist_ok=True)
...@@ -230,39 +229,59 @@ class ConfigManager: ...@@ -230,39 +229,59 @@ class ConfigManager:
f"Config directory '{self._base_dir}' is not writable: {e}" f"Config directory '{self._base_dir}' is not writable: {e}"
) from e ) from e
def load_config_set(self, kernel_name: str) -> ConfigSet: def _load_platform_file(self, kernel_name: str, platform: str) -> dict[str, Any]:
config_path = self.get_config_file_path(kernel_name) config_path = self.get_config_file_path(kernel_name, platform)
if not config_path.exists(): if not config_path.exists():
return ConfigSet.from_dict(kernel_name, {}) return {}
try: try:
with open(config_path) as f: with open(config_path) as f:
data = json.load(f) return json.load(f)
return ConfigSet.from_dict(kernel_name, data)
except (json.JSONDecodeError, OSError) as e: except (json.JSONDecodeError, OSError) as e:
logger.error("Failed to load config file %s: %s", config_path, e) logger.error("Failed to load config file %s: %s", config_path, e)
return {}
def load_config_set(self, kernel_name: str) -> ConfigSet:
kernel_dir = self.get_kernel_dir(kernel_name)
if not kernel_dir.is_dir():
return ConfigSet.from_dict(kernel_name, {}) return ConfigSet.from_dict(kernel_name, {})
data: dict[str, Any] = {}
for platform_file in sorted(kernel_dir.glob("*.json")):
platform = platform_file.stem
try:
with open(platform_file) as f:
platform_data = json.load(f)
data[platform] = platform_data
except (json.JSONDecodeError, OSError) as e:
logger.error("Failed to load config file %s: %s", platform_file, e)
return ConfigSet.from_dict(kernel_name, data)
def get_platform_configs( def get_platform_configs(
self, kernel_name: str, platform: str self, kernel_name: str, platform: str
) -> dict[str, helion.Config]: ) -> dict[str, helion.Config]:
config_set = self.load_config_set(kernel_name) platform_data = self._load_platform_file(kernel_name, platform)
if not platform_data:
return {}
config_set = ConfigSet.from_dict(kernel_name, {platform: platform_data})
config_keys = config_set.get_config_keys(platform) config_keys = config_set.get_config_keys(platform)
return { return {
config_key: config_set.get_config(platform, config_key) config_key: config_set.get_config(platform, config_key)
for config_key in config_keys for config_key in config_keys
} }
def save_config_set(self, config_set: ConfigSet) -> Path: def save_config_set(self, config_set: ConfigSet) -> Path:
config_path = self.get_config_file_path(config_set.kernel_name) kernel_dir = self.get_kernel_dir(config_set.kernel_name)
config_path.parent.mkdir(parents=True, exist_ok=True) kernel_dir.mkdir(parents=True, exist_ok=True)
with open(config_path, "w") as f: full_data = config_set.to_dict()
json.dump(config_set.to_dict(), f, indent=2) for platform, platform_data in full_data.items():
platform_path = kernel_dir / f"{platform}.json"
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
logger.info("Saved config to: %s", config_path) return kernel_dir
return config_path
def save_configs( def save_configs(
self, self,
...@@ -271,11 +290,18 @@ class ConfigManager: ...@@ -271,11 +290,18 @@ class ConfigManager:
configs: dict[str, "helion.Config"], configs: dict[str, "helion.Config"],
) -> Path: ) -> Path:
"""Save configs for a kernel/platform, merging with existing.""" """Save configs for a kernel/platform, merging with existing."""
config_set = self.load_config_set(kernel_name) platform_data = self._load_platform_file(kernel_name, platform)
for config_key, config in configs.items(): for config_key, config in configs.items():
config_set.set_config(platform, config_key, config) platform_data[config_key] = json.loads(config.to_json())
return self.save_config_set(config_set)
platform_path = self.get_config_file_path(kernel_name, platform)
platform_path.parent.mkdir(parents=True, exist_ok=True)
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
return platform_path
def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool: def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool:
config_set = self.load_config_set(kernel_name) platform_data = self._load_platform_file(kernel_name, platform)
return config_set.has_config(platform, config_key) return config_key in platform_data
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment